diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29e5254d33..105c42a503 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,7 +39,8 @@ jobs: - name: Install optional test requirements run: | - python -m pip install fairscale iopath transformers + python -m pip install iopath transformers pyarrow + python -m pip install git+https://github.com/facebookresearch/fairscale.git@master - name: Lint with flake8 run: | diff --git a/README.md b/README.md index 5fedac7eec..839dd8e1de 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,9 @@ We provide reference implementations of various sequence modeling papers: ### What's New: +* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) +* February 2021 [Added LASER training code](examples/laser/README.md) +* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) * December 2020: [GottBERT model and code released](examples/gottbert/README.md) * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) @@ -68,14 +71,14 @@ We provide reference implementations of various sequence modeling papers: * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) * October 2020: [Added CRISS models and code](examples/criss/README.md) + +
Previous updates

+ * September 2020: [Added Linformer code](examples/linformer/README.md) * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - -

Previous updates

- * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) * April 2020: [Quant-Noise code released](examples/quant_noise/README.md) @@ -108,6 +111,8 @@ We provide reference implementations of various sequence modeling papers: * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration +* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) +* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) with a convenient `torch.hub` interface: diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 04c797fe50..6a15298382 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -120,7 +120,7 @@ class LanguageModelingConfig(FairseqDataclass): ... @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(LegacyFairseqTask): +class LanguageModelingTask(FairseqTask): ... @classmethod def setup_task(cls, cfg: LanguageModelingConfig): diff --git a/examples/bart/README.md b/examples/bart/README.md index 013a809be6..4050a724ee 100644 --- a/examples/bart/README.md +++ b/examples/bart/README.md @@ -179,38 +179,23 @@ with open('glue_data/MNLI/dev_matched.tsv') as fin: ``` #### Evaluating the `bart.large.cnn` model: -Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- Follow instructions [here](https://github.com/abisee/cnn-dailymail) to download and process into data-files such that `test.source` and `test.target` has one line for each non-tokenized sample. +- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores +- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search. + In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`. -```python -bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn') -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('test.source') as source, open('test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() -``` - -Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). +In `fairseq`, summaries can be generated using: + +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir pytorch/fairseq \ + --model-file bart.large.cnn \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` + +For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge). ```bash export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar diff --git a/examples/bart/README.summarization.md b/examples/bart/README.summarization.md index d7fecc9ce6..8727584f2b 100644 --- a/examples/bart/README.summarization.md +++ b/examples/bart/README.summarization.md @@ -80,42 +80,23 @@ Expected training time is about `5 hours`. Training time can be reduced with dis Use TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2 for Xsum task ### Inference for CNN-DM test data using above trained checkpoint. -After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet: +After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using `eval_cnn.py`, for example -```python -import torch -from fairseq.models.bart import BARTModel - -bart = BARTModel.from_pretrained( - 'checkpoints/', - checkpoint_file='checkpoint_best.pt', - data_name_or_path='cnn_dm-bin' -) - -bart.cuda() -bart.eval() -bart.half() -count = 1 -bsz = 32 -with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout: - sline = source.readline().strip() - slines = [sline] - for sline in source: - if count % bsz == 0: - with torch.no_grad(): - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() - slines = [] - - slines.append(sline.strip()) - count += 1 - if slines != []: - hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) - for hypothesis in hypotheses_batch: - fout.write(hypothesis + '\n') - fout.flush() +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo +``` +For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10: +```bash +cp data-bin/cnn_dm/dict.source.txt checkpoints/ +python examples/bart/summarize.py \ + --model-dir checkpoints \ + --model-file checkpoint_best.pt \ + --src cnn_dm/test.source \ + --out cnn_dm/test.hypo \ + --xsum-kwargs ``` -Use beam=6, lenpen=1.0, max_len_b=60, min_len=10 for Xsum Generation diff --git a/examples/bart/summarize.py b/examples/bart/summarize.py new file mode 100644 index 0000000000..04435f80e3 --- /dev/null +++ b/examples/bart/summarize.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.models.bart import BARTModel +import argparse + +XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3) +CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3) + + +@torch.no_grad() +def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs): + count = 1 + + # if n_obs is not None: bsz = min(bsz, n_obs) + + with open(infile) as source, open(outfile, "w") as fout: + sline = source.readline().strip() + slines = [sline] + for sline in source: + if n_obs is not None and count > n_obs: + break + if count % bsz == 0: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + slines = [] + + slines.append(sline.strip()) + count += 1 + + if slines != []: + hypotheses_batch = bart.sample(slines, **eval_kwargs) + for hypothesis in hypotheses_batch: + fout.write(hypothesis + "\n") + fout.flush() + + +def main(): + """ + Usage:: + + python examples/bart/summarize.py \ + --model-dir $HOME/bart.large.cnn \ + --model-file model.pt \ + --src $HOME/data-bin/cnn_dm/test.source + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + required=True, + type=str, + default="bart.large.cnn/", + help="path containing model file and src_dict.txt", + ) + parser.add_argument( + "--model-file", + default="checkpoint_best.pt", + help="where in model_dir are weights saved", + ) + parser.add_argument( + "--src", default="test.source", help="text to summarize", type=str + ) + parser.add_argument( + "--out", default="test.hypo", help="where to save summaries", type=str + ) + parser.add_argument("--bsz", default=32, help="where to save summaries", type=int) + parser.add_argument( + "--n", default=None, help="how many examples to summarize", type=int + ) + parser.add_argument( + "--xsum-kwargs", + action="store_true", + default=False, + help="if true use XSUM_KWARGS else CNN_KWARGS", + ) + args = parser.parse_args() + eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS + if args.model_dir == "pytorch/fairseq": + bart = torch.hub.load("pytorch/fairseq", args.model_file) + else: + bart = BARTModel.from_pretrained( + args.model_dir, + checkpoint_file=args.model_file, + data_name_or_path=args.model_dir, + ) + bart = bart.eval() + if torch.cuda.is_available(): + bart = bart.cuda().half() + generate( + bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs + ) + + +if __name__ == "__main__": + main() diff --git a/examples/fully_sharded_data_parallel/README.md b/examples/fully_sharded_data_parallel/README.md new file mode 100644 index 0000000000..d620f0e4f1 --- /dev/null +++ b/examples/fully_sharded_data_parallel/README.md @@ -0,0 +1,177 @@ +# Fully Sharded Data Parallel (FSDP) + +## Overview +Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and +[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel +training can be made significantly more efficient by sharding the model +parameters and optimizer state across data parallel workers. These ideas are +encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided +by [fairscale](https://github.com/facebookresearch/fairscale/). + +Compared to PyTorch DDP: +* FSDP produces identical results as PyTorch DDP (it's still synchronous data parallel training) +* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs +* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass +* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs + +FSDP is fully supported in fairseq via the following new arguments: +* `--ddp-backend=fully_sharded`: enables full sharding via FSDP +* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`) +* `--no-reshard-after-forward`: increases training speed for large models (1B+ params) and is similar to ZeRO stage 2 +* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal + +

Limitations

+ +FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP): +* while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.) +* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of these and other limitations. + +

+ +
How it works

+ +Fully Sharded Data Parallel + +See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed +explanation of how FSDP works. + +

+ +## Example usage + +The following examples illustrate how to train a very large language model with +13 billion parameters on 1 GPU by offloading parameters and optimizer states to +CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs. + +These examples use the WikiText-103 dataset for demonstration purposes, but +in practice a much larger dataset will be needed to achieve good results. +Follow the [instructions here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data) +to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary. + +### 13B params on 1 V100 GPU (with CPU offloading) + +The following command trains a 13B parameter GPT-3 model on a single V100 GPU +using the `--cpu-offload` feature to offload parameters and optimizer states to +CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the +`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)), +which further saves memory in exchange for a small increase in computation. + +**Requirements:** +- Install the latest master version of fairscale: `pip install git+https://github.com/facebookresearch/fairscale.git@master` +- You'll need 32GB of GPU memory and ~256GB of system memory to train the 13B param model. +- If you have less system memory, the 6.7B param model can be trained with ~128GB of system memory, just set `--arch transformer_lm_gpt3_6_7` +- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command. + +**Notes:** +- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow. +- The `--cpu-offload` feature requires training in mixed precision (`--fp16`). +- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading. +- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`). + +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 +``` + +
Example output

+ +``` +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs) +2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"} +2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"} +2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"} +2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"} +2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"} +2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"} +2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"} +2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"} +2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"} +2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"} +2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"} +2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds +``` + +

+ +### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding) + +FSDP can also shard the parameters and optimizer states across multiple GPUs, +reducing memory requirements significantly. On 8 x 32GB GPUs, sharding enables +training the same 13B parameter model *without offloading the parameters to +CPU*. However, without CPU offloading we'd only be able to fit a batch size of +1 per GPU, which would cause training speed to suffer. + +We obtain the best performance on 8 GPUs by combining full sharding and CPU +offloading. The following command trains the same 13B parameter GPT-3 model as +before on 8 x 32GB V100 GPUs; training speed increases superlinearly from ~310 +words per second to ~3200 words per second. + +```bash +OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + fairseq-train data-bin/wikitext-103-roberta-bpe-bin \ + --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \ + --cpu-offload --checkpoint-activations \ + --task language_modeling --tokens-per-sample 2048 --batch-size 8 \ + --arch transformer_lm_gpt3_13 \ + --optimizer cpu_adam --adam-betas "(0.9,0.98)" \ + --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \ + --max-update 10 --no-save --log-format json --log-interval 1 +``` + +
Example output

+ +``` +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920) +(...) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs) +2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8 +(...) +Adam Optimizer #0 is created with AVX2 arithmetic capability. +Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1 +(...) +2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"} +2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"} +2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0 +2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0 +2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"} +2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"} +2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"} +2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"} +2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"} +2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"} +2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"} +2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"} +2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10 +2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset +2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below) +2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"} +2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds +``` + +

diff --git a/examples/laser/README.md b/examples/laser/README.md new file mode 100644 index 0000000000..66acada04f --- /dev/null +++ b/examples/laser/README.md @@ -0,0 +1,144 @@ +# LASER Language-Agnostic SEntence Representations + +LASER is a library to calculate and use multilingual sentence embeddings. + +You can find more information about LASER and how to use it on the official [LASER repository](https://github.com/facebookresearch/LASER). + +This folder contains source code for training LASER embeddings. + + +## Prepare data and configuration file + +Binarize your data with fairseq, as described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing). + +Create a json config file with this format: +``` +{ + "src_vocab": "/path/to/spm.src.cvocab", + "tgt_vocab": "/path/to/spm.tgt.cvocab", + "train": [ + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang1-tgtlang0/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang1-tgtlang1/train.srclang1", + "tgt": "/path/to/srclang1-tgtlang1/train.tgtlang1" + }, + { + "type": "translation", + "id": 0, + "src": "/path/to/srclang2-tgtlang0/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang0/train.tgtlang0" + }, + { + "type": "translation", + "id": 1, + "src": "/path/to/srclang2-tgtlang1/train.srclang2", + "tgt": "/path/to/srclang2-tgtlang1/train.tgtlang1" + }, + ... + ], + "valid": [ + { + "type": "translation", + "id": 0, + "src": "/unused", + "tgt": "/unused" + } + ] +} +``` +where paths are paths to binarized indexed fairseq dataset files. +`id` represents the target language id. + + +## Training Command Line Example + +``` +fairseq-train \ + /path/to/configfile_described_above.json \ + --user-dir examples/laser/laser_src \ + --log-interval 100 --log-format simple \ + --task laser --arch laser_lstm \ + --save-dir . \ + --optimizer adam \ + --lr 0.001 \ + --lr-scheduler inverse_sqrt \ + --clip-norm 5 \ + --warmup-updates 90000 \ + --update-freq 2 \ + --dropout 0.0 \ + --encoder-dropout-out 0.1 \ + --max-tokens 2000 \ + --max-epoch 50 \ + --encoder-bidirectional \ + --encoder-layers 5 \ + --encoder-hidden-size 512 \ + --decoder-layers 1 \ + --decoder-hidden-size 2048 \ + --encoder-embed-dim 320 \ + --decoder-embed-dim 320 \ + --decoder-lang-embed-dim 32 \ + --warmup-init-lr 0.001 \ + --disable-validation +``` + + +## Applications + +We showcase several applications of multilingual sentence embeddings +with code to reproduce our results (in the directory "tasks"). + +* [**Cross-lingual document classification**](https://github.com/facebookresearch/LASER/tree/master/tasks/mldoc) using the + [*MLDoc*](https://github.com/facebookresearch/MLDoc) corpus [2,6] +* [**WikiMatrix**](https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix) + Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia [7] +* [**Bitext mining**](https://github.com/facebookresearch/LASER/tree/master/tasks/bucc) using the + [*BUCC*](https://comparable.limsi.fr/bucc2018/bucc2018-task.html) corpus [3,5] +* [**Cross-lingual NLI**](https://github.com/facebookresearch/LASER/tree/master/tasks/xnli) + using the [*XNLI*](https://www.nyu.edu/projects/bowman/xnli/) corpus [4,5,6] +* [**Multilingual similarity search**](https://github.com/facebookresearch/LASER/tree/master/tasks/similarity) [1,6] +* [**Sentence embedding of text files**](https://github.com/facebookresearch/LASER/tree/master/tasks/embed) + example how to calculate sentence embeddings for arbitrary text files in any of the supported language. + +**For all tasks, we use exactly the same multilingual encoder, without any task specific optimization or fine-tuning.** + + + +## References + +[1] Holger Schwenk and Matthijs Douze, + [*Learning Joint Multilingual Sentence Representations with Neural Machine Translation*](https://aclanthology.info/papers/W17-2619/w17-2619), + ACL workshop on Representation Learning for NLP, 2017 + +[2] Holger Schwenk and Xian Li, + [*A Corpus for Multilingual Document Classification in Eight Languages*](http://www.lrec-conf.org/proceedings/lrec2018/pdf/658.pdf), + LREC, pages 3548-3551, 2018. + +[3] Holger Schwenk, + [*Filtering and Mining Parallel Data in a Joint Multilingual Space*](http://aclweb.org/anthology/P18-2037) + ACL, July 2018 + +[4] Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. Bowman, Holger Schwenk and Veselin Stoyanov, + [*XNLI: Cross-lingual Sentence Understanding through Inference*](https://aclweb.org/anthology/D18-1269), + EMNLP, 2018. + +[5] Mikel Artetxe and Holger Schwenk, + [*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136) + arXiv, Nov 3 2018. + +[6] Mikel Artetxe and Holger Schwenk, + [*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464) + arXiv, Dec 26 2018. + +[7] Holger Schwenk, Vishrav Chaudhary, Shuo Sun, Hongyu Gong and Paco Guzman, + [*WikiMatrix: Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia*](https://arxiv.org/abs/1907.05791) + arXiv, July 11 2019. + +[8] Holger Schwenk, Guillaume Wenzek, Sergey Edunov, Edouard Grave and Armand Joulin + [*CCMatrix: Mining Billions of High-Quality Parallel Sentences on the WEB*](https://arxiv.org/abs/1911.04944) diff --git a/examples/laser/laser_src/__init__.py b/examples/laser/laser_src/__init__.py new file mode 100644 index 0000000000..9ffbd656d8 --- /dev/null +++ b/examples/laser/laser_src/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .laser_task import * # noqa +from .laser_lstm import * # noqa +from .laser_transformer import * # noqa diff --git a/examples/laser/laser_src/laser_lstm.py b/examples/laser/laser_src/laser_lstm.py new file mode 100644 index 0000000000..10df90e002 --- /dev/null +++ b/examples/laser/laser_src/laser_lstm.py @@ -0,0 +1,585 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils + +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) + + +@register_model("laser_lstm") +class LSTMModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=None, + dataset_name="", + ): + assert target_language_id is not None + + src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name) + return self.decoder( + prev_output_tokens, src_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained encoder embedding", + ) + parser.add_argument( + "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size" + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="number of encoder layers" + ) + parser.add_argument( + "--encoder-bidirectional", + action="store_true", + help="make all layers of encoder bidirectional", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-embed-path", + default=None, + type=str, + metavar="STR", + help="path to pre-trained decoder embedding", + ) + parser.add_argument( + "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size" + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="number of decoder layers" + ) + parser.add_argument( + "--decoder-out-embed-dim", + type=int, + metavar="N", + help="decoder output embedding dimension", + ) + parser.add_argument( + "--decoder-zero-init", + type=str, + metavar="BOOL", + help="initialize the decoder hidden/cell state to zero", + ) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + parser.add_argument( + "--fixed-embeddings", + action="store_true", + help="keep embeddings fixed (ENCODER ONLY)", + ) # TODO Also apply to decoder embeddings? + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument( + "--encoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for encoder input embedding", + ) + parser.add_argument( + "--encoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for encoder output", + ) + parser.add_argument( + "--decoder-dropout-in", + type=float, + metavar="D", + help="dropout probability for decoder input embedding", + ) + parser.add_argument( + "--decoder-dropout-out", + type=float, + metavar="D", + help="dropout probability for decoder output", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + pretrained_encoder_embed = None + if args.encoder_embed_path: + pretrained_encoder_embed = load_pretrained_embedding_from_file( + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim + ) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LSTMEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_size=args.encoder_hidden_size, + num_layers=args.encoder_layers, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + bidirectional=args.encoder_bidirectional, + pretrained_embed=pretrained_encoder_embed, + fixed_embeddings=args.fixed_embeddings, + ) + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + zero_init=options.eval_bool(args.decoder_zero_init), + encoder_embed_dim=args.encoder_embed_dim, + encoder_output_units=encoder.output_units, + pretrained_embed=pretrained_decoder_embed, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + return cls(encoder, decoder) + + +class LSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_value=0.0, + fixed_embeddings=False, + ): + super().__init__(dictionary) + self.num_layers = num_layers + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + else: + self.embed_tokens = pretrained_embed + if fixed_embeddings: + self.embed_tokens.weight.requires_grad = False + + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_tokens, src_lengths, dataset_name): + if self.left_pad: + # convert left-padding to right-padding + src_tokens = utils.convert_padding_direction( + src_tokens, + self.padding_idx, + left_to_right=True, + ) + + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + try: + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + except BaseException: + raise Exception(f"Packing failed in dataset {dataset_name}") + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value + ) + x = F.dropout(x, p=self.dropout_out, training=self.training) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return torch.cat( + [ + torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view( + 1, bsz, self.output_units + ) + for i in range(self.num_layers) + ], + dim=0, + ) + + final_hiddens = combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + # Set padded outputs to -inf so they are not selected by max-pooling + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + return { + "sentemb": sentemb, + "encoder_out": (x, final_hiddens, final_cells), + "encoder_padding_mask": encoder_padding_mask + if encoder_padding_mask.any() + else None, + } + + def reorder_encoder_out(self, encoder_out_dict, new_order): + encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select( + 0, new_order + ) + encoder_out_dict["encoder_out"] = tuple( + eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"] + ) + if encoder_out_dict["encoder_padding_mask"] is not None: + encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out_dict + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) # an arbitrary large number + + +class LSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + zero_init=False, + encoder_embed_dim=512, + encoder_output_units=512, + pretrained_embed=None, + num_langs=1, + lang_embed_dim=0, + ): + super().__init__(dictionary) + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.hidden_size = hidden_size + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=encoder_output_units + embed_dim + lang_embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + if zero_init: + self.sentemb2init = None + else: + self.sentemb2init = Linear( + encoder_output_units, 2 * num_layers * hidden_size + ) + + if lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(num_langs, lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + def forward( + self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0 + ): + sentemb = encoder_out_dict["sentemb"] + encoder_out = encoder_out_dict["encoder_out"] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + # get outputs from encoder + encoder_outs, _, _ = encoder_out[:3] + srclen = encoder_outs.size(0) + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # embed language identifier + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + # TODO Should we dropout here??? + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental generation) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is not None: + prev_hiddens, prev_cells, input_feed = cached_state + else: + num_layers = len(self.layers) + if self.sentemb2init is None: + prev_hiddens = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + prev_cells = [ + x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers) + ] + else: + init = self.sentemb2init(sentemb) + prev_hiddens = [ + init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size] + for i in range(num_layers) + ] + prev_cells = [ + init[ + :, + (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size, + ] + for i in range(num_layers) + ] + input_feed = x.data.new(bsz, self.hidden_size).zero_() + + attn_scores = x.data.new(srclen, seqlen, bsz).zero_() + outs = [] + for j in range(seqlen): + if self.embed_lang is None: + input = torch.cat((x[j, :, :], sentemb), dim=1) + else: + input = torch.cat((x[j, :, :], sentemb, langemb), dim=1) + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # hidden state becomes the input to the next layer + input = F.dropout(hidden, p=self.dropout_out, training=self.training) + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + out = hidden + out = F.dropout(out, p=self.dropout_out, training=self.training) + + # input feeding + input_feed = out + + # save final output + outs.append(out) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, + incremental_state, + "cached_state", + (prev_hiddens, prev_cells, input_feed), + ) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + attn_scores = attn_scores.transpose(0, 2) + + # project back to size of vocabulary + if hasattr(self, "additional_fc"): + x = self.additional_fc(x) + x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.fc_out(x) + + return x, attn_scores + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state(self, incremental_state, "cached_state", new_state) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return int(1e5) # an arbitrary large number + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(m.weight, -0.1, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def LSTMCell(input_size, hidden_size, **kwargs): + m = nn.LSTMCell(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m + + +@register_model_architecture("laser_lstm", "laser_lstm") +def base_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.decoder_zero_init = getattr(args, "decoder_zero_init", "0") + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) + args.fixed_embeddings = getattr(args, "fixed_embeddings", False) diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py new file mode 100644 index 0000000000..c8ac805f54 --- /dev/null +++ b/examples/laser/laser_src/laser_task.py @@ -0,0 +1,326 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from collections import OrderedDict, defaultdict +import json +import os +import logging + +from fairseq import options, models +from fairseq.data import ( + data_utils, + Dictionary, + LanguagePairDataset, + IndexedDataset, + FairseqDataset, +) +from .multitask_data_utils import ( + MultitaskDatasetWrapper, + MultidatasetEpochBatchIterator, +) + + +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("laser") +class LaserTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "configfile", metavar="PATH", help="dataset configuration file in json" + ) + parser.add_argument( + "--weighting-alpha", + type=float, + default=None, + help="alpha for automatic weighting", + ) + parser.add_argument( + "--raw-text", action="store_true", help="load raw text dataset" + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left (default: True)", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left (default: False)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks): + super().__init__(args) + self.config = config + self.src_dictionary = src_dictionary + self.tgt_dictionary = tgt_dictionary + self.num_tasks = num_tasks + + @classmethod + def setup_task(cls, args, **kwargs): + with open(args.configfile, "r") as f: + config = json.load(f) + num_tasks = max(dataset["id"] for dataset in config["train"]) + 1 + + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + src_dictionary = Dictionary.load(config["src_vocab"]) + tgt_dictionary = Dictionary.load(config["tgt_vocab"]) + + logger.info( + "| src Dictionary {} : {} types".format( + config["src_vocab"], len(src_dictionary) + ) + ) + logger.info( + "| tgt Dictionary {} : {} types".format( + config["tgt_vocab"], len(tgt_dictionary) + ) + ) + + return cls(args, config, src_dictionary, tgt_dictionary, num_tasks) + + # Experimental overriding for backtranslation + def build_model(self, args): + model = models.build_model(args, self) + return model + + def dataset(self, split): + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + return self.datasets[split] + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + + def indexed_dataset(path, dictionary): + if self.args.raw_text: + raise Exception("Unable to handle raw text.") + dataset = IndexedDataset(path, fix_lua_indexing=True) + + return dataset + + pair_datasets = OrderedDict() + + if split == "valid": + self.datasets[split] = pair_datasets + return + + if split not in self.config: + raise FileNotFoundError( + "Dataset not found in config file: {}".format(split) + ) + + size_by_corpus = defaultdict(int) + size_sum = 0 + size_sum_with_subsampling = 0 + init_pair_datasets = {} + + for dataset_config in self.config[split]: + src_path = os.path.dirname(dataset_config["src"]) + corpus_name = src_path.split("/")[-2] + language_pair_name = src_path.split("/")[-1] + pair_datasets_key = corpus_name + "-" + language_pair_name + + logger.info(f"loading... {pair_datasets_key}") + if "src" in dataset_config: + src_dataset = indexed_dataset( + dataset_config["src"], self.src_dictionary + ) + else: + src_dataset = None + + if "tgt" in dataset_config: + tgt_dataset = indexed_dataset( + dataset_config["tgt"], self.tgt_dictionary + ) + else: + tgt_dataset = None + + dataset = LanguagePairDataset( + src_dataset, + src_dataset.sizes, + self.src_dictionary, + tgt_dataset, + tgt_dataset.sizes, + self.tgt_dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + + if pair_datasets_key in init_pair_datasets: + logger.warning( + f"Ignoring already added {pair_datasets_key}. " + f"Consider using `sample` key in order to upsample." + ) + else: + init_pair_datasets[pair_datasets_key] = { + "dataset": dataset, + "sample": dataset_config.get("sample", None), + "id": dataset_config.get("id", None), + "len": len(dataset), + } + + length_sum = 0 + weighted_freqs_sum = 0 + freq_per_dataset = {} + vmax = 0 + vmin = 1 + weighted_freq_per_dataset = {} + + if self.args.weighting_alpha: + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + length_sum += len(init_pair_datasets[key]["dataset"]) + + for key in init_pair_datasets: + if init_pair_datasets[key]["sample"] is None: + val = float(init_pair_datasets[key]["len"]) / length_sum + freq_per_dataset[key] = val + weighted_freqs_sum += val ** self.args.weighting_alpha + + for key in freq_per_dataset: + val = ( + freq_per_dataset[key] ** self.args.weighting_alpha + / weighted_freqs_sum + ) + vmin = min(vmin, val) + vmax = max(vmax, val) + weighted_freq_per_dataset[key] = val + + for pair_datasets_key in init_pair_datasets: + dataset_config = init_pair_datasets[pair_datasets_key] + dataset = dataset_config["dataset"] + sample = dataset_config["sample"] + if sample is None: + sample = 1.0 + + if pair_datasets_key in weighted_freq_per_dataset: + w = vmax / weighted_freq_per_dataset[pair_datasets_key] + sample = w + + sample = round(sample) + + initial_sample = sample + initial_pair_datasets_key = pair_datasets_key + + while sample >= 1.0: + assert ( + pair_datasets_key not in pair_datasets + ), f"{pair_datasets_key} already in" + size_sum_with_subsampling += len(dataset) + pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper( + dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key + ) + size_sum += len(dataset) + sample -= 1.0 + pair_datasets_key += "-up" + + assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}" + + logger.info( + f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}" + ) + size_by_corpus[corpus_name] += len(dataset) + + self.datasets[split] = pair_datasets + logger.info( + f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}" + ) + + @property + def source_dictionary(self): + return self.src_dictionary + + @property + def target_dictionary(self): + return self.tgt_dictionary + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + # initialize the dataset with the correct starting epoch + for _, dt in dataset.items(): + dt.set_epoch(epoch) + + indices = OrderedDict() + batch_sampler = OrderedDict() + + with data_utils.numpy_seed(seed + epoch): + for key, dt in dataset.items(): + logger.info(f"\t ordered_indices {key}") + indices[key] = dt.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + for key, dt in dataset.items(): + logger.info(f"\t filter_by_size {key}") + indices[key], ignored = dt.filter_indices_by_size( + indices[key], max_positions + ) + + for key, dt in dataset.items(): + logger.info(f"\t batch_by_size {key}") + batch_sampler[key] = data_utils.batch_by_size( + indices[key], + dt.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + epoch_iter = MultidatasetEpochBatchIterator( + dataset=dataset, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + + return epoch_iter diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py new file mode 100644 index 0000000000..0be030994f --- /dev/null +++ b/examples/laser/laser_src/laser_transformer.py @@ -0,0 +1,354 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import Any, Dict, List, Optional +from torch import Tensor + +import torch +import torch.nn as nn + +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + base_architecture, + Embedding, + TransformerModel, + TransformerEncoder, + TransformerDecoder, +) +from fairseq.modules import ( + TransformerDecoderLayer, +) + +logger = logging.getLogger(__name__) + + +@register_model("laser_transformer") +class LaserTransformerModel(FairseqEncoderDecoderModel): + """Train Transformer for LASER task + + Requires --task laser + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens=None, + tgt_tokens=None, + tgt_lengths=None, + target_language_id=-1, + dataset_name="", + ): + laser_encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder( + prev_output_tokens, laser_encoder_out, lang_id=target_language_id + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--decoder-lang-embed-dim", + type=int, + metavar="N", + help="decoder language embedding dimension", + ) + + @classmethod + def build_model(cls, args, task): + base_laser_transformer_architecture(args) + + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + def load_embed_tokens(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + return Embedding(num_embeddings, embed_dim, padding_idx) + + encoder_embed_tokens = load_embed_tokens( + task.source_dictionary, args.encoder_embed_dim + ) + decoder_embed_tokens = load_embed_tokens( + task.target_dictionary, args.decoder_embed_dim + ) + num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0 + + encoder = LaserTransformerEncoder( + args, task.source_dictionary, encoder_embed_tokens + ) + + decoder = LaserTransformerDecoder( + args, + task.target_dictionary, + decoder_embed_tokens, + num_langs=num_langs, + lang_embed_dim=args.decoder_lang_embed_dim, + ) + + return cls(encoder, decoder) + + +class LaserTransformerEncoder(TransformerEncoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, src_tokens, *args, **kwargs): + encoder_out = super().forward(src_tokens, *args, **kwargs) + + x = encoder_out["encoder_out"][0] # T x B x C + padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1) + + if padding_mask.any(): + x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x) + + # Build the sentence embedding by max-pooling over the encoder outputs + sentemb = x.max(dim=0)[0] + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return {"sentemb": [sentemb]} # B x C + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Same as the one in transformer.py, with new_sentemb + """ + if len(encoder_out["sentemb"]) == 0: + new_sentemb = [] + else: + new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)] + + return { + "sentemb": new_sentemb, # B x C + } + + +class LaserTransformerDecoder(TransformerDecoder): + def __init__(self, args, dictionary, *kargs, **kwargs): + self.num_langs = kwargs.get("num_langs", 1) + self.lang_embed_dim = kwargs.get("lang_embed_dim", 0) + kwargs.pop("num_langs", None) + kwargs.pop("lang_embed_dim", None) + + super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True) + + if self.lang_embed_dim == 0: + self.embed_lang = None + else: + self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim) + nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1) + + if self.output_projection is not None: + laser_output_embed_dim = ( + self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + self.output_projection = nn.Linear( + laser_output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, + mean=0, + std=laser_output_embed_dim ** -0.5, + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + decoder_embed_dim = args.decoder_embed_dim + args.decoder_embed_dim = ( + decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim + ) + res = TransformerDecoderLayer(args, no_encoder_attn=True) + args.decoder_embed_dim = decoder_embed_dim + + return res + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + lang_id: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + bsz, seqlen = prev_output_tokens.size() + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + if self.embed_lang is not None: + lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id) + langemb = self.embed_lang(lang_ids) + langemb = langemb.unsqueeze(0) + repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * ( + len(langemb.shape) - 1 + ) + x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1) + + sentemb = encoder_out["sentemb"][0] + sentemb = sentemb.unsqueeze(0) + + repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1) + x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + None, + None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + lang_id: Optional[int] = None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + assert lang_id is not None + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + lang_id=lang_id, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + +@register_model_architecture("laser_transformer", "laser_transformer") +def base_laser_transformer_architecture(args): + base_architecture(args) + args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0) diff --git a/examples/laser/laser_src/multitask_data_utils.py b/examples/laser/laser_src/multitask_data_utils.py new file mode 100644 index 0000000000..b05caea267 --- /dev/null +++ b/examples/laser/laser_src/multitask_data_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import numpy as np + +from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators + + +class MultiItr(object): + def __init__(self, itr): + self.itr = itr + self._counts = [0 for x in itr] + + def __len__(self): + return sum(len(itr) for itr in self.itr) + + def __iter__(self): + return self + + def __next__(self): + ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)] + idx = ratios.index(min(ratios)) + self._counts[idx] += 1 + return next(self.itr[idx]) + + +class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating): + """A wrapper around multiple epoch batch iterators.""" + + def __init__( + self, + dataset, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + ): + + assert isinstance(dataset, OrderedDict) + assert len(dataset) + assert isinstance(dataset[next(iter(dataset))], FairseqDataset) + + self.iterators = [] + + self.epoch = epoch + for key, dt in dataset.items(): + epoch_iter = iterators.EpochBatchIterator( + dataset=dt, + collate_fn=dt.collater, + batch_sampler=batch_sampler[key], + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=0, + epoch=epoch, + ) + self.iterators.append(epoch_iter) + + def __len__(self): + return sum(len(itr) for itr in self.iterators) + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s. + return MultiItr( + [ + itr.next_epoch_itr( + shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus + ) + for itr in self.iterators + ] + ) + + def end_of_epoch(self): + return all(itr.end_of_epoch() for itr in self.iterators) + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + + epochs = [itr.next_epoch_idx for itr in self.iterators] + self.epoch = epochs[0] + assert all(epoch == self.epoch for epoch in epochs) + + return self.epoch + + @property + def iterations_in_epoch(self): + return sum(itr.iterations_in_epoch for itr in self.iterators) + + def state_dict(self): + return { + "iterators": [it.state_dict() for it in self.iterators], + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + for it, d in zip(self.iterators, state_dict["iterators"]): + it.load_state_dict(d) + + +class MultitaskDatasetWrapper(BaseWrapperDataset): + """A wrapper for a multitask dataset.""" + + def __init__(self, dataset, target_language_id, sample=1.0, name=""): + super().__init__(dataset) + self.target_language_id = target_language_id + self.sample = sample + self.name = name + + def collater(self, *args, **kwargs): + ans = self.dataset.collater(*args, **kwargs) + if "net_input" in ans: + ans["net_input"]["target_language_id"] = self.target_language_id + ans["net_input"]["dataset_name"] = self.name + return ans + + def num_tokens(self, *args, **kwargs): + return self.dataset.num_tokens(*args, **kwargs) + + def ordered_indices(self, *args, **kwargs): + indices = self.dataset.ordered_indices(*args, **kwargs) + # Hacky solution for sampling + size = int(self.sample * indices.shape[0]) + + return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size])) + + def size(self, index: int): + return self.dataset.size(index) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/examples/linformer/linformer_src/models/linformer_roberta.py b/examples/linformer/linformer_src/models/linformer_roberta.py index be5d8e85ec..18ad44f079 100644 --- a/examples/linformer/linformer_src/models/linformer_roberta.py +++ b/examples/linformer/linformer_src/models/linformer_roberta.py @@ -11,9 +11,15 @@ import torch from fairseq import utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.roberta import RobertaEncoder, RobertaModel +from fairseq.models.roberta import ( + init_bert_params, + roberta_base_architecture, + roberta_large_architecture, + RobertaEncoder, + RobertaModel, +) -from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder +from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder logger = logging.getLogger(__name__) @@ -66,30 +72,10 @@ def __init__(self, args, dictionary): super().__init__(args, dictionary) self.register_buffer("version", torch.tensor(2)) - def build_encoder(self, args, dictionary): - return LinformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - compressed=args.compressed, - shared_kv_compressed=args.shared_kv_compressed, - shared_layer_kv_compressed=args.shared_layer_kv_compressed, - freeze_compress=args.freeze_compress, - ) + def build_encoder(self, args, dictionary, embed_tokens): + encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) @@ -115,25 +101,11 @@ def upgrade_state_dict_named(self, state_dict, name): @register_model_architecture("linformer_roberta", "linformer_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.compressed = getattr(args, "compressed", 4) args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) args.freeze_compress = getattr(args, "freeze_compress", 0) + roberta_base_architecture(args) @register_model_architecture("linformer_roberta", "linformer_roberta_base") @@ -143,18 +115,5 @@ def linformer_roberta_base_architecture(args): @register_model_architecture("linformer_roberta", "linformer_roberta_large") def linformer_roberta_large_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 24) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.compressed = getattr(args, "compressed", 4) - args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0) - args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0) + roberta_large_architecture(args) + base_architecture(args) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py index 3cdca01235..44f7989bd8 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py @@ -6,12 +6,12 @@ import math import torch.nn as nn -from fairseq.modules import TransformerSentenceEncoder +from fairseq.models.transformer import TransformerEncoder -from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer +from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer -class LinformerSentenceEncoder(TransformerSentenceEncoder): +class LinformerTransformerEncoder(TransformerEncoder): """ Implementation for a Bi-directional Linformer based Sentence Encoder used in BERT/XLM style pre-trained models. @@ -35,135 +35,20 @@ class LinformerSentenceEncoder(TransformerSentenceEncoder): in format B x C. """ - def __init__( - self, - padding_idx: int, - vocab_size: int, - num_encoder_layers: int = 6, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - layerdrop: float = 0.0, - max_seq_len: int = 256, - num_segments: int = 2, - use_position_embeddings: bool = True, - offset_positions_by_padding: bool = True, - encoder_normalize_before: bool = False, - apply_bert_init: bool = False, - activation_fn: str = "relu", - learned_pos_embedding: bool = True, - embed_scale: float = None, - freeze_embeddings: bool = False, - n_trans_layers_to_freeze: int = 0, - export: bool = False, - traceable: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - compressed: int = 4, - shared_kv_compressed: int = 0, - shared_layer_kv_compressed: int = 0, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.shared_kv_compressed = shared_kv_compressed - self.shared_layer_kv_compressed = shared_layer_kv_compressed + def __init__(self, args, dictionary, embed_tokens): self.compress_layer = None - self.freeze_compress = freeze_compress - - super().__init__( - padding_idx=padding_idx, - vocab_size=vocab_size, - num_encoder_layers=num_encoder_layers, - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - layerdrop=layerdrop, - max_seq_len=max_seq_len, - num_segments=num_segments, - use_position_embeddings=use_position_embeddings, - offset_positions_by_padding=offset_positions_by_padding, - encoder_normalize_before=encoder_normalize_before, - apply_bert_init=apply_bert_init, - activation_fn=activation_fn, - learned_pos_embedding=learned_pos_embedding, - embed_scale=embed_scale, - freeze_embeddings=freeze_embeddings, - n_trans_layers_to_freeze=n_trans_layers_to_freeze, - export=export, - traceable=traceable, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args, dictionary, embed_tokens) - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - q_noise, - qn_block_size, - ): - if self.shared_layer_kv_compressed == 1 and self.compress_layer is None: + def build_encoder_layer(self, args): + if self.args.shared_layer_kv_compressed == 1 and self.compress_layer is None: compress_layer = nn.Linear( - self.max_seq_len, self.max_seq_len // self.compressed + self.args.max_positions, + self.args.max_positions // self.args.compressed, ) # intialize parameters for compressed layer nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2)) - if self.freeze_compress == 1: + if self.args.freeze_compress == 1: compress_layer.weight.requires_grad = False self.compress_layer = compress_layer - return LinformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, - shared_compress_layer=( - None if self.shared_layer_kv_compressed == 0 else self.compress_layer - ), - freeze_compress=self.freeze_compress, - ) - - def upgrade_state_dict_named(self, state_dict, name): - prefix = name + "." if name != "" else "" - items_to_add = {} - keys_to_remove = [] - - # update key name for shared layer in new version of code - for k in state_dict.keys(): - if k.startswith(prefix + "compress_layer"): - if self.shared_layer_kv_compressed: - for layer_idx in range(len(self.layers)): - new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format( - layer_idx, - k[len(prefix + "compress_layer.") :], - ) - items_to_add[new_k] = state_dict[k] - - for k in keys_to_remove: - del state_dict[k] - - for key, value in items_to_add.items(): - state_dict[key] = value + return LinformerTransformerEncoderLayer(args, self.compress_layer) diff --git a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py index 0b80fabefe..7e2caa0340 100644 --- a/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +++ b/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py @@ -3,88 +3,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable - import torch from fairseq import utils -from fairseq.modules import TransformerSentenceEncoderLayer +from fairseq.modules import TransformerEncoderLayer from .multihead_linear_attention import MultiheadLinearAttention -class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): +class LinformerTransformerEncoderLayer(TransformerEncoderLayer): """ Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained models. """ - def __init__( - self, - embedding_dim: int = 768, - ffn_embedding_dim: int = 3072, - num_attention_heads: int = 8, - dropout: float = 0.1, - attention_dropout: float = 0.1, - activation_dropout: float = 0.1, - activation_fn: str = "relu", - export: bool = False, - q_noise: float = 0.0, - qn_block_size: int = 8, - init_fn: Callable = None, - compressed: int = 1, - max_seq_len: int = 256, - shared_kv_compressed: int = 0, - shared_compress_layer: any = None, - freeze_compress: int = 0, - ) -> None: - - # Initialize linformer parameters - self.compressed = compressed - self.max_seq_len = max_seq_len - self.shared_kv_compressed = shared_kv_compressed - self.freeze_compress = freeze_compress - + def __init__(self, args, shared_compress_layer): # wrap in a list so it's not automatically registered by PyTorch self.shared_compress_layer = [shared_compress_layer] - super().__init__( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - q_noise=q_noise, - qn_block_size=qn_block_size, - ) + super().__init__(args) + self.register_buffer("version", torch.tensor(2)) - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - self_attention, - q_noise, - qn_block_size, - ): + def build_self_attention(self, embed_dim, args): return MultiheadLinearAttention( embed_dim, - num_attention_heads, - dropout=dropout, + args.encoder_attention_heads, + dropout=args.dropout, self_attention=True, - q_noise=q_noise, - qn_block_size=qn_block_size, - compressed=self.compressed, - max_seq_len=self.max_seq_len, - shared_kv_compressed=self.shared_kv_compressed, + q_noise=args.quant_noise_pq, + qn_block_size=args.quant_noise_pq_block_size, + compressed=args.compressed, + max_seq_len=args.max_positions, + shared_kv_compressed=args.shared_kv_compressed, shared_compress_layer=self.shared_compress_layer[0], - freeze_compress=self.freeze_compress, + freeze_compress=args.freeze_compress, ) def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) prefix = name + "." if name != "" else "" # some old checkpoints had weight sharing implemented incorrectly @@ -101,14 +57,7 @@ def upgrade_state_dict_named(self, state_dict, name): self.shared_compress_layer[0].weight.size(0), ) ] - self.self_attn = self.build_self_attention( - self.embedding_dim, - self.num_attention_heads, - dropout=self.attention_dropout, - self_attention=True, - q_noise=self.q_noise, - qn_block_size=self.qn_block_size, - ) + self.self_attn = self.build_self_attention(self.embed_dim, self.args) # delete shared_compress_layer, since it's already copied to # self_attn.compress_k.weight del state_dict[f"{prefix}shared_compress_layer.weight"] diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index fb40a80836..e109a8e269 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -4,13 +4,12 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List, Tuple import torch import torch.nn as nn from fairseq import metrics, utils from fairseq.models import register_model, register_model_architecture -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS, @@ -155,7 +154,13 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder): to the decoder. """ - def forward(self, src_tokens, src_lengths, **kwargs): + def forward( + self, + src_tokens, + src_lengths: Optional[Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[Tensor] = None + ): """ Runs the `forward()` method of the parent Transformer class. Then adds the source tokens into the encoder output tuple. @@ -169,6 +174,10 @@ def forward(self, src_tokens, src_lengths, **kwargs): shape `(batch, src_len)` src_lengths (torch.LongTensor): lengths of each source sentence of shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings Returns: namedtuple: @@ -184,7 +193,15 @@ def forward(self, src_tokens, src_lengths, **kwargs): - **src_tokens** (Tensor): input token ids of shape `(batch, src_len)` """ - encoder_out = super().forward(src_tokens, src_lengths, **kwargs) + encoder_out = self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. return { "encoder_out": encoder_out["encoder_out"], # T x B x C "encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T @@ -236,7 +253,7 @@ def __init__(self, args, dictionary, embed_tokens): def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, alignment_layer: Optional[int] = 0, @@ -248,8 +265,8 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing - encoder_out (EncoderOut, optional): output from the encoder, used - for encoder-side attention + encoder_out (optional): output from the encoder, used for + encoder-side attention incremental_state (dict, optional): dictionary used for storing state during :ref:`Incremental decoding` features_only (bool, optional): only return features without @@ -284,10 +301,21 @@ def forward( predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) p_gens = torch.sigmoid(p_gens) - x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens) + # Torchscript complains if encoder_out or attn are None because + # `output_layer()` signature expects tensors instead + attn: Optional[Tensor] = extra["attn"][0] + assert encoder_out is not None + assert attn is not None + x = self.output_layer(x, attn, encoder_out["src_tokens"][0], p_gens) return x, extra - def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): + def output_layer( + self, + features: Tensor, + attn: Tensor, + src_tokens: Tensor, + p_gens: Tensor + ) -> Tensor: """ Project features to the vocabulary size and mix with the attention distributions. @@ -296,7 +324,10 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): p_gens = self.force_p_gen # project back to size of vocabulary - logits = super().output_layer(features, **kwargs) + if self.adaptive_softmax is None: + logits = self.output_projection(features) + else: + logits = features batch_size = logits.shape[0] output_length = logits.shape[1] @@ -306,7 +337,7 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # The final output distribution will be a mixture of the normal output # distribution (softmax of logits) and attention weights. - gen_dists = super().get_normalized_probs( + gen_dists = self.get_normalized_probs_scriptable( (logits, None), log_probs=False, sample=None ) gen_dists = torch.mul(gen_dists, p_gens) @@ -330,7 +361,12 @@ def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): # Final distributions, [batch_size, output_length, num_types]. return gen_dists + attn_dists - def get_normalized_probs(self, net_output, log_probs, sample): + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): """ Get normalized probabilities (or log probs) from a net's output. Pointer-generator network output is already normalized. @@ -375,8 +411,19 @@ class Embedding(nn.Embedding): """ __constants__ = ["unk_idx"] - def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx): - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + # Torchscript: Inheriting from Embedding class produces an error when exporting to Torchscript + # -> RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details + # It's happening because max_norm attribute from nn.Embedding is None by default and it cannot be + # cast to a C++ type + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int], + unk_idx: int, + max_norm: Optional[float] = float("inf"), + ): + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm) self.unk_idx = unk_idx nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) nn.init.constant_(self.weight[padding_idx], 0) @@ -385,7 +432,10 @@ def forward(self, input): input = torch.where( input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input ) - return super().forward(input) + return nn.functional.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) @register_model_architecture( diff --git a/examples/simultaneous_translation/__init__.py b/examples/simultaneous_translation/__init__.py index 446fc86c8a..5835316ba9 100644 --- a/examples/simultaneous_translation/__init__.py +++ b/examples/simultaneous_translation/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import criterions, eval, models # noqa +from . import models # noqa diff --git a/examples/simultaneous_translation/criterions/__init__.py b/examples/simultaneous_translation/criterions/__init__.py deleted file mode 100644 index 08791bfff3..0000000000 --- a/examples/simultaneous_translation/criterions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import os - - -for file in os.listdir(os.path.dirname(__file__)): - if file.endswith(".py") and not file.startswith("_"): - criterion_name = file[: file.find(".py")] - importlib.import_module( - "examples.simultaneous_translation.criterions." + criterion_name - ) diff --git a/examples/simultaneous_translation/docs/baseline.md b/examples/simultaneous_translation/docs/baseline.md deleted file mode 100644 index d9bf1a1117..0000000000 --- a/examples/simultaneous_translation/docs/baseline.md +++ /dev/null @@ -1,178 +0,0 @@ -# **Baseline Simultaneous Translation** ---- - -This is an instruction of training and evaluating a *wait-k* simultanoes LSTM model on MUST-C English-Gernam Dataset. - -[STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://https://www.aclweb.org/anthology/P19-1289/) - - -## **Requirements** -Install fairseq (make sure to use the correct branch): -``` -git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git -cd fairseq -pip install -e . -``` - -Assuming that fairseq is installed in a directory called `FAIRSEQ`. - -Install SentencePiece. One easy way is to use anaconda: - -``` -conda install -c powerai sentencepiece -``` - -Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/. -We will assume that the data is downloaded in a directory called `DATA_ROOT`. - - -## **Text-to-text Model** ---- -### Data Preparation -Train a SentencePiece model: -```shell -for lang in en de; do - python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . -``` - -Process the data with the SentencePiece model: -```shell -proc_dir=proc -mkdir -p $proc_dir -for split in train dev tst-COMMON tst-HE; do - for lang in en de; do - spm_encode \ - --model unigram-$lang-10000-3000/spm.model \ - < $DATA_ROOT/data/$split/txt/$split.$lang \ - > $proc_dir/$split.spm.$lang - done -done -``` - -Binarize the data: - -```shell -proc_dir=proc -fairseq-preprocess \ - --source-lang en --target-lang de \ - --trainpref $proc_dir/train.spm \ - --validpref $proc_dir/dev.spm \ - --testpref $proc_dir/tst-COMMON.spm \ - --thresholdtgt 0 \ - --thresholdsrc 0 \ - --workers 20 \ - --destdir ./data-bin/mustc_en_de \ -``` - -### Training - - -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --simul-type waitk \ - --waitk-lagging 2 \ - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## **Speech-to-text Model** ---- -### Data Preparation -First, segment wav files. -```shell -python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \ - --datapath $DATA_ROOT -``` -Similar to text-to-text model, train a Sentencepiecemodel, but only train on German -```Shell -python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \ - --data-path $DATA_ROOT/data \ - --vocab-size 10000 \ - --max-frame 3000 \ - --model-type unigram \ - --lang $lang \ - --out-path . -``` -## Training -```shell -mkdir -p checkpoints -CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \ - --save-dir checkpoints \ - --arch berard_simul_text_iwslt \ - --waitk-lagging 2 \ - --waitk-stride 10 \ - --input-feat-per-channel 40 \ - --encoder-hidden-size 512 \ - --output-layer-dim 128 \ - --decoder-num-layers 3 \ - --task speech_translation \ - --user-dir $FAIRSEQ/examples/simultaneous_translation - --optimizer adam \ - --max-epoch 100 \ - --lr 0.001 \ - --clip-norm 5.0 \ - --batch-size 128 \ - --log-format json \ - --log-interval 10 \ - --criterion cross_entropy_acc \ - --user-dir $FAIRSEQ/examples/simultaneous_translation -``` - -## Evaluation ---- -### Evaluation Server -For text translation models, the server is set up as follow give input file and reference file. - -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --src-file $DATA_ROOT/data/dev/txt/dev.en \ - --ref-file $DATA_ROOT/data/dev/txt/dev.de -``` -For speech translation models, the input is the data direcrory. -``` shell -python ./eval/server.py \ - --hostname localhost \ - --port 12321 \ - --ref-file $DATA_ROOT \ - --data-type speech -``` - -### Decode and Evaluate with Client -Once the server is set up, run client to evaluate translation quality and latency. -```shell -# TEXT -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --src-spm unigram-en-10000-3000/spm.model\ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt - -# SPEECH -python $fairseq_dir/examples/simultaneous_translation/evaluate.py \ - data-bin/mustc_en_de \ - --user-dir $FAIRSEQ/examples/simultaneous_translation \ - --data-type speech \ - --tgt-spm unigram-de-10000-3000/spm.model\ - -s en -t de \ - --path checkpoints/checkpoint_best.pt -``` diff --git a/examples/simultaneous_translation/docs/evaluation.md b/examples/simultaneous_translation/docs/evaluation.md deleted file mode 100644 index c53407354e..0000000000 --- a/examples/simultaneous_translation/docs/evaluation.md +++ /dev/null @@ -1,115 +0,0 @@ -# Introduction to evaluation interface -The simultaneous translation models from sharedtask participents are evaluated under a server-client protocol. The participents are requisted to plug in their own model API in the protocol, and submit a docker file. - -## Server-Client Protocol -An server-client protocol that will be used in evaluation. For example, when a *wait-k* model (k=3) translate the English sentence "Alice and Bob are good friends" to Genman sentence "Alice und Bob sind gute Freunde." , the evaluation process is shown as following figure. - -While every time client needs to read a new state (word or speech utterence), a "GET" request is supposed to sent over to server. Whenever a new token is generated, a "SEND" request with the word predicted (untokenized word) will be sent to server immediately. The server can hence calculate both latency and BLEU score of the sentence. - -### Server -The server code is provided and can be set up directly locally for development purpose. For example, to evaluate a text simultaneous test set, - -```shell - - python fairseq/examples/simultaneous_translation/eval/server.py \ - --hostname local_host \ - --port 1234 \ - --src-file SRC_FILE \ - --ref-file REF_FILE \ - --data-type text \ -``` -The state that server sent to client is has the following format -```json -{ - 'sent_id': Int, - 'segment_id': Int, - 'segment': String -} -``` - -### Client -The client will handle the evaluation process mentioned above. It should be out-of-box as well. The client's protocol is as following table - -|Action|Content| -|:---:|:---:| -|Request new word / utterence| ```{key: "Get", value: None}```| -|Predict word "W"| ```{key: "SEND", value: "W"}```| - - - -The core of the client module is the agent, which needs to be modified to different models accordingly. The abstract class of agent is as follow, the evaluation process happens in the `decode()` function. -```python -class Agent(object): - "an agent needs to follow this pattern" - def __init__(self, *args, **kwargs): - ... - - def init_states(self): - # Initializing states - ... - - def update_states(self, states, new_state): - # Update states with given new state from server - # TODO (describe the states) - ... - - def finish_eval(self, states, new_state): - # Check if evaluation is finished - ... - - def policy(self, state: list) -> dict: - # Provide a action given current states - # The action can only be either - # {key: "GET", value: NONE} - # or - # {key: "SEND", value: W} - ... - - def reset(self): - # Reset agent - ... - - def decode(self, session): - - states = self.init_states() - self.reset() - - # Evaluataion protocol happens here - while True: - # Get action from the current states according to self.policy() - action = self.policy(states) - - if action['key'] == GET: - # Read a new state from server - new_state = session.get_src() - states = self.update_states(states, new_state) - - if self.finish_eval(states, new_state): - # End of document - break - - elif action['key'] == SEND: - # Send a new prediction to server - session.send_hypo(action['value']) - - # Clean the history, wait for next sentence - if action['value'] == DEFAULT_EOS: - states = self.init_states() - self.reset() - else: - raise NotImplementedError - - -``` -Here an implementation of agent of text [*wait-k* model](somelink). Notice that the tokenization is not considered. - -## Quality -The quality is measured by detokenized BLEU. So make sure that the predicted words sent to server are detokenized. An implementation is can be find [here](some link) - -## Latency -The latency metrics are -* Average Proportion -* Average Lagging -* Differentiable Average Lagging -Again Thery will also be evaluated on detokenized text. - diff --git a/examples/simultaneous_translation/models/convtransformer_simul_trans.py b/examples/simultaneous_translation/models/convtransformer_simul_trans.py new file mode 100644 index 0000000000..0b15e93fea --- /dev/null +++ b/examples/simultaneous_translation/models/convtransformer_simul_trans.py @@ -0,0 +1,155 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from fairseq import checkpoint_utils +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_text import ( + ConvTransformerModel, + convtransformer_espnet, + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.modules.augmented_memory_attention import ( + augmented_memory, + SequenceEncoder, + AugmentedMemoryConvTransformerEncoder, +) +from fairseq.models.speech_to_text.modules.emformer import emformer_encoder + + +@register_model("convtransformer_simul_trans") +class SimulConvTransformerModel(ConvTransformerModel): + """ + Implementation of the paper: + + SimulMT to SimulST: Adapting Simultaneous Text Translation to + End-to-End Simultaneous Speech Translation + + https://www.aclweb.org/anthology/2020.aacl-main.58.pdf + """ + + @staticmethod + def add_args(parser): + super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser) + parser.add_argument( + "--train-monotonic-only", + action="store_true", + default=False, + help="Only train monotonic attention", + ) + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + tgt_dict = task.tgt_dict + + from examples.simultaneous_translation.models.transformer_monotonic_attention import ( + TransformerMonotonicDecoder, + ) + + decoder = TransformerMonotonicDecoder(args, tgt_dict, embed_tokens) + + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + +@register_model_architecture( + "convtransformer_simul_trans", "convtransformer_simul_trans_espnet" +) +def convtransformer_simul_trans_espnet(args): + convtransformer_espnet(args) + + +@register_model("convtransformer_augmented_memory") +@augmented_memory +class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel): + @classmethod + def build_encoder(cls, args): + encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args)) + + if getattr(args, "load_pretrained_encoder_from", None) is not None: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + + return encoder + + +@register_model_architecture( + "convtransformer_augmented_memory", "convtransformer_augmented_memory" +) +def augmented_memory_convtransformer_espnet(args): + convtransformer_espnet(args) + + +# ============================================================================ # +# Convtransformer +# with monotonic attention decoder +# with emformer encoder +# ============================================================================ # + + +@emformer_encoder +class ConvTransformerEmformerEncoder(ConvTransformerEncoder): + pass + + +@register_model("convtransformer_emformer") +class ConvtransformerEmformer(SimulConvTransformerModel): + @staticmethod + def add_args(parser): + super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser) + + parser.add_argument( + "--segment-length", + type=int, + metavar="N", + help="length of each segment (not including left context / right context)", + ) + parser.add_argument( + "--segment-left-context", + type=int, + help="length of left context in a segment", + ) + parser.add_argument( + "--segment-right-context", + type=int, + help="length of right context in a segment", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right context for the segment.", + ) + parser.add_argument( + "--amtrf-tanh-on-mem", + default=False, + action="store_true", + help="whether to use tanh on memory vector", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEmformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + +@register_model_architecture( + "convtransformer_emformer", + "convtransformer_emformer", +) +def convtransformer_emformer_base(args): + convtransformer_espnet(args) diff --git a/examples/simultaneous_translation/models/transformer_monotonic_attention.py b/examples/simultaneous_translation/models/transformer_monotonic_attention.py index ab8adf3aab..65c12c6f5b 100644 --- a/examples/simultaneous_translation/models/transformer_monotonic_attention.py +++ b/examples/simultaneous_translation/models/transformer_monotonic_attention.py @@ -10,17 +10,20 @@ TransformerMonotonicDecoderLayer, TransformerMonotonicEncoderLayer, ) -from fairseq.models import register_model, register_model_architecture +from fairseq.models import ( + register_model, + register_model_architecture, +) from fairseq.models.transformer import ( - TransformerDecoder, - TransformerEncoder, TransformerModel, + TransformerEncoder, + TransformerDecoder, base_architecture, transformer_iwslt_de_en, transformer_vaswani_wmt_en_de_big, + transformer_vaswani_wmt_en_fr_big, ) - DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -33,7 +36,7 @@ def build_encoder(cls, args, src_dict, embed_tokens): @register_model("transformer_monotonic") -class TransformerMonotonicModel(TransformerModel): +class TransformerModelSimulTrans(TransformerModel): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): return TransformerMonotonicEncoder(args, src_dict, embed_tokens) @@ -62,60 +65,6 @@ def _indices_from_states(self, states): return src_indices, None, tgt_indices - def predict_from_states(self, states): - decoder_states = self.decoder.output_layer(states["decoder_features"]) - lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True) - - index = lprobs.argmax(dim=-1) - - token = self.decoder.dictionary.string(index) - - return token, index[0, 0].item() - - def decision_from_states(self, states): - """ - This funcion take states dictionary as input, and gives the agent - a decision of whether read a token from server. Moreover, the decoder - states are also calculated here so we can directly generate a target - token without recompute every thing - """ - - self.eval() - - if len(states["tokens"]["src"]) == 0: - return 0 - - src_indices, src_lengths, tgt_indices = self._indices_from_states(states) - - # Update encoder states if needed - if ( - "encoder_states" not in states - or states["encoder_states"][0].size(1) <= states["steps"]["src"] - ): - encoder_out_dict = self.encoder(src_indices, src_lengths) - states["encoder_states"] = encoder_out_dict - else: - encoder_out_dict = states["encoder_states"] - - # online means we still need tokens to feed the model - states["model_states"]["online"] = not ( - states["finish_read"] - and len(states["tokens"]["src"]) == states["steps"]["src"] - ) - - states["model_states"]["steps"] = states["steps"] - - x, outputs = self.decoder.forward( - prev_output_tokens=tgt_indices, - encoder_out=encoder_out_dict, - incremental_state=states["model_states"], - features_only=True, - ) - - states["decoder_features"] = x - - return outputs["action"] - class TransformerMonotonicEncoder(TransformerEncoder): def __init__(self, args, dictionary, embed_tokens): @@ -178,13 +127,18 @@ def pre_attention( if positions is not None: x += positions + x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) - encoder_out = encoder_out_dict.encoder_out - encoder_padding_mask = encoder_out_dict.encoder_padding_mask + encoder_out = encoder_out_dict["encoder_out"][0] + encoder_padding_mask = ( + encoder_out_dict["encoder_padding_mask"][0] + if len(encoder_out_dict["encoder_padding_mask"]) > 0 + else None + ) return x, encoder_out, encoder_padding_mask @@ -200,6 +154,18 @@ def post_attention(self, x): return x + def clear_cache(self, incremental_state, end_id=None): + """ + Clear cache in the monotonic layers. + The cache is generated because of a forward pass of decode but no prediction. + end_id is the last idx of the layers + """ + if end_id is None: + end_id = len(self.layers) + + for j in range(end_id): + self.layers[j].prune_incremental_state(incremental_state) + def extract_features( self, prev_output_tokens, encoder_out, incremental_state=None, **unused ): @@ -236,12 +202,16 @@ def extract_features( attn_list.append(attn) if incremental_state is not None: - curr_steps = layer.get_steps(incremental_state) + curr_steps = layer.get_head_steps(incremental_state) step_list.append(curr_steps) - if incremental_state.get("online", False): + if incremental_state.get("online", True): + # Online indicates that the encoder states are still changing p_choose = ( - attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t()) + attn["p_choose"] + .squeeze(0) + .squeeze(1) + .gather(1, curr_steps.t()) ) new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps) @@ -250,24 +220,10 @@ def extract_features( # We need to prune the last self_attn saved_state # if model decide not to read # otherwise there will be duplicated saved_state - for j in range(i + 1): - self.layers[j].prune_incremental_state(incremental_state) + self.clear_cache(incremental_state, i + 1) return x, {"action": 0} - if incremental_state is not None and not incremental_state.get("online", False): - # Here is for fast evaluation - fastest_step = ( - torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1 - ) - - if "fastest_step" in incremental_state: - incremental_state["fastest_step"] = torch.cat( - [incremental_state["fastest_step"], fastest_step], dim=1 - ) - else: - incremental_state["fastest_step"] = fastest_step - x = self.post_attention(x) return x, { @@ -287,7 +243,7 @@ def reorder_incremental_state(self, incremental_state, new_order): @register_model_architecture("transformer_monotonic", "transformer_monotonic") -def base_monotonic_rchitecture(args): +def base_monotonic_architecture(args): base_architecture(args) args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False) @@ -297,7 +253,7 @@ def base_monotonic_rchitecture(args): ) def transformer_monotonic_iwslt_de_en(args): transformer_iwslt_de_en(args) - base_monotonic_rchitecture(args) + base_monotonic_architecture(args) # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) diff --git a/examples/simultaneous_translation/modules/fixed_pre_decision.py b/examples/simultaneous_translation/modules/fixed_pre_decision.py new file mode 100644 index 0000000000..cc5e7ad532 --- /dev/null +++ b/examples/simultaneous_translation/modules/fixed_pre_decision.py @@ -0,0 +1,193 @@ +from functools import partial + +import torch +import math +import torch.nn.functional as F + +from . import register_monotonic_attention +from .monotonic_multihead_attention import ( + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionHardAligned, + MonotonicMultiheadAttentionInfiniteLookback, +) + + +def fixed_pooling_monotonic_attention(monotonic_attention): + def create_model(monotonic_attention, klass): + class FixedStrideMonotonicAttention(monotonic_attention): + def __init__(self, args): + super().__init__(args) + self.pre_decision_type = args.fixed_pre_decision_type + self.pre_decision_ratio = args.fixed_pre_decision_ratio + self.pre_decision_pad_threshold = args.fixed_pre_decision_pad_threshold + if self.pre_decision_ratio == 1: + return + + if args.fixed_pre_decision_type == "average": + self.pooling_layer = torch.nn.AvgPool1d( + kernel_size=self.pre_decision_ratio, + stride=self.pre_decision_ratio, + ceil_mode=True, + ) + elif args.fixed_pre_decision_type == "last": + + def last(key): + if key.size(2) < self.pre_decision_ratio: + return key + else: + k = key[ + :, + :, + self.pre_decision_ratio - 1 :: self.pre_decision_ratio, + ].contiguous() + if key.size(-1) % self.pre_decision_ratio != 0: + k = torch.cat([k, key[:, :, -1:]], dim=-1).contiguous() + return k + + self.pooling_layer = last + else: + raise NotImplementedError + + @staticmethod + def add_args(parser): + super( + FixedStrideMonotonicAttention, FixedStrideMonotonicAttention + ).add_args(parser) + parser.add_argument( + "--fixed-pre-decision-ratio", + type=int, + required=True, + help=( + "Ratio for the fixed pre-decision," + "indicating how many encoder steps will start" + "simultaneous decision making process." + ), + ) + parser.add_argument( + "--fixed-pre-decision-type", + default="average", + choices=["average", "last"], + help="Pooling type", + ) + parser.add_argument( + "--fixed-pre-decision-pad-threshold", + type=float, + default=0.3, + help="If a part of the sequence has pad" + ",the threshold the pooled part is a pad.", + ) + + def insert_zeros(self, x): + bsz_num_heads, tgt_len, src_len = x.size() + stride = self.pre_decision_ratio + weight = F.pad(x.new_ones(1, 1, 1), (stride - 1, 0)) + x_upsample = F.conv_transpose1d( + x.view(-1, src_len).unsqueeze(1), + weight, + stride=stride, + padding=0, + ) + return x_upsample.squeeze(1).view(bsz_num_heads, tgt_len, -1) + + def p_choose( + self, + query, + key, + key_padding_mask=None, + incremental_state=None, + **extra_args + ): + src_len = key.size(0) + tgt_len = query.size(0) + batch_size = query.size(1) + + if self.pre_decision_ratio == 1: + return super().p_choose( + query, + key, + key_padding_mask=None, + incremental_state=None, + **extra_args + ) + + key_pool = self.pooling_layer(key.transpose(0, 2)).transpose(0, 2) + + if key_padding_mask is not None: + key_padding_mask_pool = ( + self.pooling_layer(key_padding_mask.unsqueeze(0).float()) + .squeeze(0) + .gt(self.pre_decision_pad_threshold) + ) + # Make sure at least one element is not pad + key_padding_mask_pool[:, 0] = 0 + else: + key_padding_mask_pool = None + + if incremental_state is not None: + # The floor instead of ceil is used for inference + # But make sure the length key_pool at least 1 + if ( + max(1, math.floor(key.size(0) / self.pre_decision_ratio)) + ) < key_pool.size(0): + key_pool = key_pool[:-1] + if key_padding_mask_pool is not None: + key_padding_mask_pool = key_padding_mask_pool[:-1] + + p_choose_pooled = super().p_choose( + query, + key_pool, + key_padding_mask_pool, + incremental_state=incremental_state, + ) + + # Upsample, interpolate zeros + p_choose = self.insert_zeros(p_choose_pooled) + + if p_choose.size(-1) < src_len: + # Append zeros if the upsampled p_choose is shorter than src_len + p_choose = torch.cat( + [ + p_choose, + p_choose.new_zeros( + p_choose.size(0), + tgt_len, + src_len - p_choose.size(-1) + ) + ], + dim=2 + ) + else: + # can be larger than src_len because we used ceil before + p_choose = p_choose[:, :, :src_len] + p_choose[:, :, -1] = p_choose_pooled[:, :, -1] + + assert list(p_choose.size()) == [ + batch_size * self.num_heads, + tgt_len, + src_len, + ] + + return p_choose + + FixedStrideMonotonicAttention.__name__ = klass.__name__ + return FixedStrideMonotonicAttention + + return partial(create_model, monotonic_attention) + + +@register_monotonic_attention("waitk_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionWaitK) +class MonotonicMultiheadAttentionWaitkFixedStride: + pass + + +@register_monotonic_attention("hard_aligned_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionHardAligned) +class MonotonicMultiheadAttentionHardFixedStride: + pass + + +@register_monotonic_attention("infinite_lookback_fixed_pre_decision") +@fixed_pooling_monotonic_attention(MonotonicMultiheadAttentionInfiniteLookback) +class MonotonicMultiheadAttentionInfiniteLookbackFixedStride: + pass diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index c09725ac9a..49882afcd8 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn + import torch.nn.functional as F from examples.simultaneous_translation.utils.functions import ( exclusive_cumprod, @@ -30,6 +31,7 @@ def __init__(self, args): self.eps = args.attention_eps self.mass_preservation = args.mass_preservation + self.noise_type = args.noise_type self.noise_mean = args.noise_mean self.noise_var = args.noise_var @@ -43,23 +45,26 @@ def __init__(self, args): @staticmethod def add_args(parser): # fmt: off - parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation", + parser.add_argument('--no-mass-preservation', action="store_false", + dest="mass_preservation", help='Do not stay on the last token when decoding') - parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation", + parser.add_argument('--mass-preservation', action="store_true", + dest="mass_preservation", help='Stay on the last token when decoding') parser.set_defaults(mass_preservation=True) - parser.add_argument('--noise-var', type=float, default=1.0, help='Variance of discretness noise') parser.add_argument('--noise-mean', type=float, default=0.0, help='Mean of discretness noise') - parser.add_argument('--energy-bias', action="store_true", default=False, + parser.add_argument('--noise-type', type=str, default="flat", + help='Type of discretness noise') + parser.add_argument('--energy-bias', action="store_true", + default=False, help='Bias for energy') parser.add_argument('--energy-bias-init', type=float, default=-2.0, help='Initial value of the bias for energy') parser.add_argument('--attention-eps', type=float, default=1e-6, help='Epsilon when calculating expected attention') - # fmt: on def p_choose(self, *args): raise NotImplementedError @@ -67,7 +72,9 @@ def p_choose(self, *args): def input_projections(self, *args): raise NotImplementedError - def attn_energy(self, q_proj, k_proj, key_padding_mask=None): + def attn_energy( + self, q_proj, k_proj, key_padding_mask=None, attn_mask=None + ): """ Calculating monotonic energies @@ -82,7 +89,13 @@ def attn_energy(self, q_proj, k_proj, key_padding_mask=None): bsz = bsz // self.num_heads src_len = k_proj.size(1) - attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + attn_energy = ( + torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias + ) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_energy += attn_mask attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len) @@ -102,7 +115,7 @@ def expected_alignment_train(self, p_choose, key_padding_mask): q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} a_ij = p_ij q_ij - parellel solution: + Parallel solution: ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) ============================================================ @@ -139,21 +152,40 @@ def expected_alignment_train(self, p_choose, key_padding_mask): if self.mass_preservation: # Last token has the residual probabilities - alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) - - assert not torch.isnan(alpha).any(), "NaN detected in alpha." + if key_padding_mask is not None and key_padding_mask[:, -1].any(): + # right padding + batch_size = key_padding_mask.size(0) + residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0) + src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True) + src_lens = src_lens.expand( + batch_size, self.num_heads + ).contiguous().view(-1, 1) + src_lens = src_lens.expand(-1, tgt_len).contiguous() + # add back the last value + residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1) + alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals) + else: + residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0) + alpha[:, :, -1] = residuals + + if torch.isnan(alpha).any(): + # Something is wrong + raise RuntimeError("NaN in alpha.") return alpha - def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state): + def expected_alignment_infer( + self, p_choose, encoder_padding_mask, incremental_state + ): + # TODO modify this function """ Calculating mo alignment for MMA during inference time ============================================================ Expected input size p_choose: bsz * num_heads, tgt_len, src_len - key_padding_mask: bsz * src_len incremental_state: dict + encodencoder_padding_mask: bsz * src_len """ # p_choose: bsz * self.num_heads, src_len bsz_num_heads, tgt_len, src_len = p_choose.size() @@ -166,7 +198,8 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # prev_monotonic_step: bsz, num_heads bsz = bsz_num_heads // self.num_heads prev_monotonic_step = monotonic_cache.get( - "step", p_choose.new_zeros([bsz, self.num_heads]).long() + "head_step", + p_choose.new_zeros([bsz, self.num_heads]).long() ) bsz, num_heads = prev_monotonic_step.size() assert num_heads == self.num_heads @@ -175,8 +208,9 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state # p_choose: bsz, num_heads, src_len p_choose = p_choose.view(bsz, num_heads, src_len) - if key_padding_mask is not None: - src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long() + if encoder_padding_mask is not None: + src_lengths = src_len - \ + encoder_padding_mask.sum(dim=1, keepdim=True).long() else: src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len @@ -186,16 +220,16 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step = prev_monotonic_step step_offset = 0 - if key_padding_mask is not None: - if key_padding_mask[:, 0].any(): + if encoder_padding_mask is not None: + if encoder_padding_mask[:, 0].any(): # left_pad_source = True: - step_offset = key_padding_mask.sum(dim=-1, keepdim=True) + step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True) max_steps = src_lengths - 1 if self.mass_preservation else src_lengths # finish_read: bsz, num_heads finish_read = new_monotonic_step.eq(max_steps) - + p_choose_i = 1 while finish_read.sum().item() < bsz * self.num_heads: # p_choose: bsz * self.num_heads, src_len # only choose the p at monotonic steps @@ -224,23 +258,32 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state new_monotonic_step += action finish_read = new_monotonic_step.eq(max_steps) | (action == 0) - # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read - monotonic_cache["step"] = new_monotonic_step + + monotonic_cache["head_step"] = new_monotonic_step + # Whether a head is looking for new input + monotonic_cache["head_read"] = ( + new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5) + ) # alpha: bsz * num_heads, 1, src_len # new_monotonic_step: bsz, num_heads - alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter( - 1, - (step_offset + new_monotonic_step) - .view(bsz * self.num_heads, 1) - .clamp(0, src_len - 1), - 1, + alpha = ( + p_choose + .new_zeros([bsz * self.num_heads, src_len]) + .scatter( + 1, + (step_offset + new_monotonic_step) + .view(bsz * self.num_heads, 1).clamp(0, src_len - 1), + 1 + ) ) if not self.mass_preservation: alpha = alpha.masked_fill( - (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0 + (new_monotonic_step == max_steps) + .view(bsz * self.num_heads, 1), + 0 ) alpha = alpha.unsqueeze(1) @@ -249,18 +292,28 @@ def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state return alpha + def _get_monotonic_buffer(self, incremental_state): + return utils.get_incremental_state( + self, + incremental_state, + 'monotonic', + ) or {} + + def _set_monotonic_buffer(self, incremental_state, buffer): + utils.set_incremental_state( + self, + incremental_state, + 'monotonic', + buffer, + ) + def v_proj_output(self, value): raise NotImplementedError def forward( - self, - query, - key, - value, - key_padding_mask=None, - incremental_state=None, - *args, - **kwargs, + self, query, key, value, + key_padding_mask=None, attn_mask=None, incremental_state=None, + need_weights=True, static_kv=False, *args, **kwargs ): tgt_len, bsz, embed_dim = query.size() @@ -268,26 +321,31 @@ def forward( # stepwise prob # p_choose: bsz * self.num_heads, tgt_len, src_len - p_choose = self.p_choose(query, key, key_padding_mask) + p_choose = self.p_choose( + query, key, key_padding_mask, incremental_state, + ) # expected alignment alpha # bsz * self.num_heads, tgt_len, src_len if incremental_state is not None: alpha = self.expected_alignment_infer( - p_choose, key_padding_mask, incremental_state - ) + p_choose, key_padding_mask, incremental_state) else: - alpha = self.expected_alignment_train(p_choose, key_padding_mask) + alpha = self.expected_alignment_train( + p_choose, key_padding_mask) # expected attention beta # bsz * self.num_heads, tgt_len, src_len beta = self.expected_attention( - alpha, query, key, value, key_padding_mask, incremental_state + alpha, query, key, value, + key_padding_mask, attn_mask, + incremental_state ) attn_weights = beta v_proj = self.v_proj_output(value) + attn = torch.bmm(attn_weights.type_as(v_proj), v_proj) attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) @@ -298,67 +356,17 @@ def forward( alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len) p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len) - return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose} - - def reorder_incremental_state(self, incremental_state, new_order): - """Reorder buffered internal state (for incremental generation).""" - super().reorder_incremental_state(incremental_state, new_order) - input_buffer = self._get_monotonic_buffer(incremental_state) - if input_buffer is not None: - for k in input_buffer.keys(): - input_buffer[k] = input_buffer[k].index_select(0, new_order) - self._set_monotonic_buffer(incremental_state, input_buffer) - - def _get_monotonic_buffer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def _set_monotonic_buffer(self, incremental_state, buffer): - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - buffer, - ) - - def get_pointer(self, incremental_state): - return ( - utils.get_incremental_state( - self, - incremental_state, - "monotonic", - ) - or {} - ) - - def get_fastest_pointer(self, incremental_state): - return self.get_pointer(incremental_state)["step"].max(0)[0] - - def set_pointer(self, incremental_state, p_choose): - curr_pointer = self.get_pointer(incremental_state) - if len(curr_pointer) == 0: - buffer = torch.zeros_like(p_choose) - else: - buffer = self.get_pointer(incremental_state)["step"] - - buffer += (p_choose < 0.5).type_as(buffer) - - utils.set_incremental_state( - self, - incremental_state, - "monotonic", - {"step": buffer}, - ) + return attn, { + "alpha": alpha, + "beta": beta, + "p_choose": p_choose, + } @register_monotonic_attention("hard_aligned") -class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention): +class MonotonicMultiheadAttentionHardAligned( + MonotonicAttention, MultiheadAttention +): def __init__(self, args): MultiheadAttention.__init__( self, @@ -392,39 +400,36 @@ def input_projections(self, query, key, value, name): bsz = query.size(1) q = self.q_in_proj[name](query) q *= self.scaling - q = ( - q.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + q = q.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: q = None if key is not None: bsz = key.size(1) k = self.k_in_proj[name](key) - k = ( - k.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + k = k.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: k = None if value is not None: bsz = value.size(1) v = self.v_in_proj[name](value) - v = ( - v.contiguous() - .view(-1, bsz * self.num_heads, self.head_dim) - .transpose(0, 1) - ) + v = v.contiguous().view( + -1, bsz * self.num_heads, self.head_dim + ).transpose(0, 1) else: v = None return q, k, v - def p_choose(self, query, key, key_padding_mask=None): + def p_choose( + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args + ): """ Calculating step wise prob for reading and writing 1 to read, 0 to write @@ -440,7 +445,9 @@ def p_choose(self, query, key, key_padding_mask=None): """ # prepare inputs - q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic") + q_proj, k_proj, _ = self.input_projections( + query, key, None, "monotonic" + ) # attention energy attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask) @@ -473,7 +480,9 @@ def v_proj_output(self, value): @register_monotonic_attention("infinite_lookback") -class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard): +class MonotonicMultiheadAttentionInfiniteLookback( + MonotonicMultiheadAttentionHardAligned +): def __init__(self, args): super().__init__(args) self.init_soft_attention() @@ -498,30 +507,33 @@ def init_soft_attention(self): nn.init.xavier_uniform_(self.q_in_proj["soft"].weight) def expected_attention( - self, alpha, query, key, value, key_padding_mask, incremental_state + self, alpha, query, key, value, + key_padding_mask, attn_mask, incremental_state ): # monotonic attention, we will calculate milk here bsz_x_num_heads, tgt_len, src_len = alpha.size() bsz = int(bsz_x_num_heads / self.num_heads) q, k, _ = self.input_projections(query, key, None, "soft") - soft_energy = self.attn_energy(q, k, key_padding_mask) + soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask) - assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len] + assert list(soft_energy.size()) == \ + [bsz, self.num_heads, tgt_len, src_len] soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len) if incremental_state is not None: monotonic_cache = self._get_monotonic_buffer(incremental_state) - monotonic_step = monotonic_cache["step"] + 1 + monotonic_length = monotonic_cache["head_step"] + 1 step_offset = 0 if key_padding_mask is not None: if key_padding_mask[:, 0].any(): # left_pad_source = True: step_offset = key_padding_mask.sum(dim=-1, keepdim=True) - monotonic_step += step_offset + monotonic_length += step_offset mask = lengths_to_mask( - monotonic_step.view(-1), soft_energy.size(2), 1 + monotonic_length.view(-1), + soft_energy.size(2), 1 ).unsqueeze(1) soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf")) @@ -531,84 +543,82 @@ def expected_attention( beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2) else: - # bsz * num_heads, tgt_len, src_len soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0] - exp_soft_energy = torch.exp(soft_energy) - exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2) + exp_soft_energy = torch.exp(soft_energy) + self.eps + inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2)) + + beta = ( + exp_soft_energy + * torch.cumsum(inner_items.flip(dims=[2]), dim=2) + .flip(dims=[2]) + ) + + beta = beta.view(bsz, self.num_heads, tgt_len, src_len) if key_padding_mask is not None: - if key_padding_mask.any(): - exp_soft_energy_cumsum = ( - exp_soft_energy_cumsum.view( - -1, self.num_heads, tgt_len, src_len - ) - .masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps - ) - .view(-1, tgt_len, src_len) - ) - - inner_items = alpha / exp_soft_energy_cumsum - - beta = exp_soft_energy * torch.cumsum( - inner_items.flip(dims=[2]), dim=2 - ).flip(dims=[2]) + beta = beta.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).bool(), 0) + beta = beta / beta.sum(dim=3, keepdim=True) + beta = beta.view(bsz * self.num_heads, tgt_len, src_len) beta = self.dropout_module(beta) - assert not torch.isnan(beta).any(), "NaN detected in beta." + if torch.isnan(beta).any(): + # Something is wrong + raise RuntimeError("NaN in beta.") return beta @register_monotonic_attention("waitk") -class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback): +class MonotonicMultiheadAttentionWaitK( + MonotonicMultiheadAttentionInfiniteLookback +): def __init__(self, args): super().__init__(args) self.q_in_proj["soft"] = self.q_in_proj["monotonic"] self.k_in_proj["soft"] = self.k_in_proj["monotonic"] self.waitk_lagging = args.waitk_lagging - assert ( - self.waitk_lagging > 0 - ), f"Lagging has to been larger than 0, get {self.waitk_lagging}." + assert self.waitk_lagging > 0, ( + f"Lagging has to been larger than 0, get {self.waitk_lagging}." + ) @staticmethod def add_args(parser): super( - MonotonicMultiheadAttentionWaitk, - MonotonicMultiheadAttentionWaitk, + MonotonicMultiheadAttentionWaitK, + MonotonicMultiheadAttentionWaitK, ).add_args(parser) parser.add_argument( - "--waitk-lagging", type=int, required=True, help="Wait k lagging" + "--waitk-lagging", type=int, required=True, help="Wait K lagging" ) def p_choose( - self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None + self, query, key, key_padding_mask=None, + incremental_state=None, *extra_args ): """ query: bsz, tgt_len key: bsz, src_len key_padding_mask: bsz, src_len """ + if incremental_state is not None: + # Retrieve target length from incremental states + # For inference the length of query is always 1 + tgt_len = int(incremental_state["steps"]["tgt"]) + else: + tgt_len, bsz, _ = query.size() + src_len, bsz, _ = key.size() - tgt_len, bsz, _ = query.size() + p_choose = query.new_ones(bsz, tgt_len, src_len) p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1) p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1) - if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any(): - # Left pad source - # add -1 to the end - p_choose = p_choose.masked_fill( - key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1 - ) - p_choose = convert_padding_direction( - p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True - ) - p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query) - # remove -1 - p_choose[p_choose.eq(-1)] = 0 + if incremental_state is not None: + p_choose = p_choose[:, -1:] + tgt_len = 1 # Extend to each head p_choose = ( diff --git a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py index 442b7d487d..e6e1850a18 100644 --- a/examples/simultaneous_translation/modules/monotonic_transformer_layer.py +++ b/examples/simultaneous_translation/modules/monotonic_transformer_layer.py @@ -26,11 +26,19 @@ def __init__( add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) + + assert args.simul_type is not None, "A --simul-type is needed." + self.encoder_attn = build_monotonic_attention(args) self.encoder_attn_layer_norm = LayerNorm( self.embed_dim, export=getattr(args, "char_inputs", False) ) + def get_head_steps(self, incremental_state): + return self.encoder_attn._get_monotonic_buffer(incremental_state).get( + "head_step" + ) + def prune_incremental_state(self, incremental_state): def prune(module): input_buffer = module._get_input_buffer(incremental_state) diff --git a/examples/simultaneous_translation/utils/data_utils.py b/examples/simultaneous_translation/utils/data_utils.py new file mode 100644 index 0000000000..cc4729e63c --- /dev/null +++ b/examples/simultaneous_translation/utils/data_utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def calc_mean_invstddev(feature): + if len(feature.size()) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = feature.mean(0) + var = feature.var(0) + # avoid division by ~zero + eps = 1e-8 + if (var < eps).any(): + return mean, 1.0 / (torch.sqrt(var) + eps) + return mean, 1.0 / torch.sqrt(var) + + +def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def lengths_to_encoder_padding_mask(lengths, batch_first=False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = 0 for t < lengths[b] and 1 otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +def encoder_padding_mask_to_lengths( + encoder_padding_mask, max_lengths, batch_size, device +): + """ + convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor + + Conventionally, encoder output contains a encoder_padding_mask, which is + a 2-D mask in a shape (T, B), whose (t, b) element indicate whether + encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we + need to convert this mask tensor to a 1-D tensor in shape (B, ), where + [b] denotes the valid length of b-th sequence + + Args: + encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, + indicating all are valid + Return: + seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the + number of valid elements of b-th sequence + + max_lengths: maximum length of all sequence, if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(0) + + batch_size: batch size; if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(1) + + device: which device to put the result on + """ + if encoder_padding_mask is None: + return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) + + assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" + assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" + + return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/examples/speech_recognition/hydra/infer.py b/examples/speech_recognition/hydra/infer.py index 6afa066f25..b1c985bc0d 100644 --- a/examples/speech_recognition/hydra/infer.py +++ b/examples/speech_recognition/hydra/infer.py @@ -10,10 +10,9 @@ import os import shutil import sys -from argparse import Namespace from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import editdistance import torch @@ -26,7 +25,6 @@ CommonEvalConfig, DatasetConfig, DistributedTrainingConfig, FairseqDataclass, GenerationConfig) -from fairseq.dataclass.initialize import hydra_init from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.progress_bar import BaseProgressBar from fairseq.models.fairseq_model import FairseqModel diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 5a582c54af..f4efbf39c8 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -144,11 +144,11 @@ def process_predictions( print( "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"] ) - # only score top hypothesis - if not args.quiet: - logger.debug("HYPO:" + hyp_words) - logger.debug("TARGET:" + tgt_words) - logger.debug("___________________") + + if not args.quiet: + logger.info("HYPO:" + hyp_words) + logger.info("TARGET:" + tgt_words) + logger.info("___________________") hyp_words = hyp_words.split() tgt_words = tgt_words.split() @@ -216,7 +216,6 @@ def main(args, task=None, model_state=None): use_cuda = torch.cuda.is_available() and not args.cpu - logger.info("| decoding with criterion {}".format(args.criterion)) task = tasks.setup_task(args) @@ -227,7 +226,7 @@ def main(args, task=None, model_state=None): task.load_dataset(args.gen_subset) else: logger.info("| loading model(s) from {}".format(args.path)) - models, saved_cfg = checkpoint_utils.load_model_ensemble( + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( utils.split_paths(args.path), arg_overrides=ast.literal_eval(args.model_overrides), task=task, diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py index 706d9f1433..8b158293a0 100644 --- a/examples/speech_recognition/w2l_decoder.py +++ b/examples/speech_recognition/w2l_decoder.py @@ -431,7 +431,7 @@ def __init__(self, args, tgt_dict): self.silence, self.blank, self.unk_word, - self.asg_transitions, + [], self.unit_lm, ) else: diff --git a/examples/speech_to_text/README.md b/examples/speech_to_text/README.md index 4b6f89d105..f639d300d3 100644 --- a/examples/speech_to_text/README.md +++ b/examples/speech_to_text/README.md @@ -36,6 +36,9 @@ audio paths (one per line) as inputs. - [Speech-to-Text Translation (ST) on CoVoST 2](docs/covost_example.md) +- [Speech-to-Text Translation (ST) on Multilingual TEDx](docs/mtedx_example.md) +- [Simultaneous Speech-to-Text Translation (SimulST) on MuST-C](docs/simulst_mustc_example.md) + ## Updates - 02/04/2021: Added interactive decoding (`fairseq-interactive`) support. Examples: [ASR (LibriSpeech)](docs/librispeech_example.md#interactive-decoding) diff --git a/examples/speech_to_text/data_utils.py b/examples/speech_to_text/data_utils.py index 0d7c034419..3f96ffc427 100644 --- a/examples/speech_to_text/data_utils.py +++ b/examples/speech_to_text/data_utils.py @@ -14,7 +14,10 @@ import numpy as np import pandas as pd import sentencepiece as sp -from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank +from fairseq.data.audio.audio_utils import ( + _convert_to_mono, _get_kaldi_fbank, _get_torchaudio_fbank +) +import torch from tqdm import tqdm @@ -66,7 +69,7 @@ def gen_vocab( def extract_fbank_features( - waveform, + waveform: torch.FloatTensor, sample_rate: int, output_path: Optional[Path] = None, n_mel_bins: int = 80, @@ -75,8 +78,9 @@ def extract_fbank_features( if output_path is not None and output_path.is_file() and not overwrite: return - _waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers - _waveform = _waveform.squeeze().numpy() + _waveform = _convert_to_mono(waveform, sample_rate) + _waveform = _waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers + _waveform = _waveform.numpy() features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins) if features is None: @@ -126,7 +130,9 @@ def gen_config_yaml( specaugment_policy: str = "lb", prepend_tgt_lang_tag: bool = False, sampling_alpha: float = 1.0, - audio_root: str = "" + audio_root: str = "", + cmvn_type: str = "utterance", + gcmvn_path: Optional[Path] = None, ): manifest_root = manifest_root.absolute() writer = S2TDataConfigWriter(manifest_root / yaml_filename) @@ -151,8 +157,19 @@ def gen_config_yaml( if prepend_tgt_lang_tag: writer.set_prepend_tgt_lang_tag(True) writer.set_sampling_alpha(sampling_alpha) - writer.set_feature_transforms("_train", ["utterance_cmvn", "specaugment"]) - writer.set_feature_transforms("*", ["utterance_cmvn"]) + + if cmvn_type not in ["global", "utterance"]: + raise NotImplementedError + + writer.set_feature_transforms("_train", [f"{cmvn_type}_cmvn", "specaugment"]) + writer.set_feature_transforms("*", [f"{cmvn_type}_cmvn"]) + + if cmvn_type == "global": + assert gcmvn_path is not None, ( + 'Please provide path of global cmvn file.' + ) + writer.set_global_cmvn(gcmvn_path) + if len(audio_root) > 0: writer.set_audio_root(audio_root) writer.flush() @@ -206,6 +223,16 @@ def filter_manifest_df( return df[valid] +def cal_gcmvn_stats(features_list): + features = np.concatenate(features_list) + square_sums = (features ** 2).sum(axis=0) + mean = features.mean(axis=0) + features = np.subtract(features, mean) + var = square_sums / features.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-8)) + return {"mean": mean.astype("float32"), "std": std.astype("float32")} + + class S2TDataConfigWriter(object): DEFAULT_VOCAB_FILENAME = "dict.txt" DEFAULT_INPUT_FEAT_PER_CHANNEL = 80 @@ -297,6 +324,9 @@ def set_input_feat_per_channel(self, input_feat_per_channel: int = 80): def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]): self.config["bpe_tokenizer"] = bpe_tokenizer + def set_global_cmvn(self, stats_npz_path: str): + self.config["stats_npz_path"] = stats_npz_path + def set_feature_transforms(self, split: str, transforms: List[str]): if "transforms" not in self.config: self.config["transforms"] = {} diff --git a/examples/speech_to_text/docs/mtedx_example.md b/examples/speech_to_text/docs/mtedx_example.md new file mode 100644 index 0000000000..25b4556aff --- /dev/null +++ b/examples/speech_to_text/docs/mtedx_example.md @@ -0,0 +1,200 @@ +[[Back]](..) + +# S2T Example: Speech Translation (ST) on Multilingual TEDx + +[Multilingual TEDx](https://arxiv.org/abs/2102.01757) is multilingual corpus for speech recognition and +speech translation. The data is derived from TEDx talks in 8 source languages +with translations to a subset of 5 target languages. + +## Data Preparation +[Download](http://openslr.org/100/) and unpack Multilingual TEDx data to a path +`${MTEDX_ROOT}/${LANG_PAIR}`, then preprocess it with +```bash +# additional Python packages for S2T data processing/model training +pip install pandas torchaudio soundfile sentencepiece + +# Generate TSV manifests, features, vocabulary +# and configuration for each language +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr \ + --vocab-type unigram --vocab-size 1000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st \ + --vocab-type unigram --vocab-size 1000 + +# Add vocabulary and configuration for joint data +# (based on the manifests and features generated above) +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task asr --joint \ + --vocab-type unigram --vocab-size 8000 +python examples/speech_to_text/prep_mtedx_data.py \ + --data-root ${MTEDX_ROOT} --task st --joint \ + --vocab-type unigram --vocab-size 8000 +``` +The generated files (manifest, features, vocabulary and data configuration) will be added to +`${MTEDX_ROOT}/${LANG_PAIR}` (per-language data) and `MTEDX_ROOT` (joint data). + + +## ASR +#### Training +Spanish as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset valid_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For joint model (using ASR data from all 8 languages): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml \ + --train-subset train_es-es_asr,train_fr-fr_asr,train_pt-pt_asr,train_it-it_asr,train_ru-ru_asr,train_el-el_asr,train_ar-ar_asr,train_de-de_asr \ + --valid-subset valid_es-es_asr,valid_fr-fr_asr,valid_pt-pt_asr,valid_it-it_asr,valid_ru-ru_asr,valid_el-el_asr,valid_ar-ar_asr,valid_de-de_asr \ + --save-dir ${MULTILINGUAL_ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 +``` +where `MULTILINGUAL_ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs +with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. + +#### Inference & Evaluation +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-es \ + --config-yaml config_asr.yaml --gen-subset test --task speech_to_text \ + --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe + +# For models trained on joint data +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ASR_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANG in es fr pt it ru el ar de; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_asr.yaml --gen-subset test_${LANG}-${LANG}_asr --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct --remove-bpe +done +``` +#### Results +| Data | --arch | Params | Es | Fr | Pt | It | Ru | El | Ar | De | +|--------------|--------------------|--------|------|------|------|------|------|-------|-------|-------| +| Monolingual | s2t_transformer_xs | 10M | 46.4 | 45.6 | 54.8 | 48.0 | 74.7 | 109.5 | 104.4 | 111.1 | + + +## ST +#### Training +Es-En as example: +```bash +fairseq-train ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset valid_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_xs --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 +``` +For multilingual model (all 12 directions): +```bash +fairseq-train ${MTEDX_ROOT} \ + --config-yaml config_st.yaml \ + --train-subset train_el-en_st,train_es-en_st,train_es-fr_st,train_es-it_st,train_es-pt_st,train_fr-en_st,train_fr-es_st,train_fr-pt_st,train_it-en_st,train_it-es_st,train_pt-en_st,train_pt-es_st,train_ru-en_st \ + --valid-subset valid_el-en_st,valid_es-en_st,valid_es-fr_st,valid_es-it_st,valid_es-pt_st,valid_fr-en_st,valid_fr-es_st,valid_fr-pt_st,valid_it-en_st,valid_it-es_st,valid_pt-en_st,valid_pt-es_st,valid_ru-en_st \ + --save-dir ${MULTILINGUAL_ST_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-epoch 200 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --dropout 0.3 --label-smoothing 0.1 \ + --skip-invalid-size-inputs-valid-test \ + --keep-last-epochs 10 --update-freq 8 --patience 10 \ + --ignore-prefix-size 1 \ + --load-pretrained-encoder-from ${PRETRAINED_ENCODER} +``` +where `ST_SAVE_DIR` (`MULTILINGUAL_ST_SAVE_DIR`) is the checkpoint root path. The ST encoder is pre-trained by ASR +for faster training and better performance: `--load-pretrained-encoder-from <(JOINT_)ASR checkpoint path>`. We set +`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU. +For multilingual models, we prepend target language ID token as target BOS, which should be excluded from +the training loss via `--ignore-prefix-size 1`. + +#### Inference & Evaluation +Average the last 10 checkpoints and evaluate on the `test` split: +```bash +CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt +python scripts/average_checkpoints.py \ + --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${MTEDX_ROOT}/es-en \ + --config-yaml config_st.yaml --gen-subset test --task speech_to_text \ + --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 50000 --beam 5 --scoring sacrebleu --remove-bpe + +# For multilingual models +python scripts/average_checkpoints.py \ + --inputs ${MULTILINGUAL_ST_SAVE_DIR} --num-epoch-checkpoints 10 \ + --output "${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME}" + +for LANGPAIR in es-en es-fr es-pt fr-en fr-es fr-pt pt-en pt-es it-en it-es ru-en el-en; do + fairseq-generate ${MTEDX_ROOT} \ + --config-yaml config_st.yaml --gen-subset test_${LANGPAIR}_st --task speech_to_text \ + --prefix-size 1 --path ${MULTILINGUAL_ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --scoring sacrebleu --remove-bpe +done +``` +For multilingual models, we force decoding from the target language ID token (as BOS) via `--prefix-size 1`. + +#### Results +| Data | --arch | Params | Es-En | Es-Pt | Es-Fr | Fr-En | Fr-Es | Fr-Pt | Pt-En | Pt-Es | It-En | It-Es | Ru-En | El-En | +|--------------|--------------------|-----|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------| +| Bilingual | s2t_transformer_xs | 10M | 7.0 | 12.2 | 1.7 | 8.9 | 10.6 | 7.9 | 8.1 | 8.7 | 6.4 | 1.0 | 0.7 | 0.6 | +| Multilingual | s2t_transformer_s | 31M | 12.3 | 17.4 | 6.1 | 12.0 | 13.6 | 13.2 | 12.0 | 13.7 | 10.7 | 13.1 | 0.6 | 0.8 | + + +## Citation +Please cite as: +``` +@misc{salesky2021mtedx, + title={Multilingual TEDx Corpus for Speech Recognition and Translation}, + author={Elizabeth Salesky and Matthew Wiesner and Jacob Bremerman and Roldano Cattoni and Matteo Negri and Marco Turchi and Douglas W. Oard and Matt Post}, + year={2021}, +} + +@inproceedings{wang2020fairseqs2t, + title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, + author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino}, + booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations}, + year = {2020}, +} + +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` + +[[Back]](..) diff --git a/examples/speech_to_text/docs/mustc_example.md b/examples/speech_to_text/docs/mustc_example.md index 7628dc77ef..79df0aafdc 100644 --- a/examples/speech_to_text/docs/mustc_example.md +++ b/examples/speech_to_text/docs/mustc_example.md @@ -11,7 +11,7 @@ `${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with ```bash # additional Python packages for S2T data processing/model training -pip install pandas torchaudio sentencepiece +pip install pandas torchaudio soundfile sentencepiece # Generate TSV manifests, features, vocabulary # and configuration for each language diff --git a/examples/speech_to_text/docs/simulst_mustc_example.md b/examples/speech_to_text/docs/simulst_mustc_example.md new file mode 100644 index 0000000000..3452806a1c --- /dev/null +++ b/examples/speech_to_text/docs/simulst_mustc_example.md @@ -0,0 +1,123 @@ +# Simultaneous Speech Translation (SimulST) on MuST-C + +This is a tutorial of training and evaluating a transformer *wait-k* simultaneous model on MUST-C English-Germen Dataset, from [SimulMT to SimulST: Adapting Simultaneous Text Translation to End-to-End Simultaneous Speech Translation](https://www.aclweb.org/anthology/2020.aacl-main.58.pdf). + +[MuST-C](https://www.aclweb.org/anthology/N19-1202) is multilingual speech-to-text translation corpus with 8-language translations on English TED talks. + +## Data Preparation +[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path +`${MUSTC_ROOT}/en-${TARGET_LANG_ID}`, then preprocess it with +```bash +# Additional Python packages for S2T data processing/model training +pip install pandas torchaudio sentencepiece + +# Generate TSV manifests, features, vocabulary, +# global cepstral and mean estimation, +# and configuration for each language +cd fairseq + +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task asr \ + --vocab-type unigram --vocab-size 10000 \ + --cmvn-type global + +python examples/speech_to_text/prep_mustc_data.py \ + --data-root ${MUSTC_ROOT} --task st \ + --vocab-type unigram --vocab-size 10000 \ + --cmvn-type global +``` + +## ASR Pretraining +We need a pretrained offline ASR model. Assuming the save directory of the ASR model is `${ASR_SAVE_DIR}` +``` +fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \ + --save-dir ${ASR_SAVE_DIR} --num-workers 4 --max-tokens 40000 --max-update 100000 \ + --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \ + --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \ + --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 +``` +A pretrained ASR checkpoint can be downloaded [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1_en_de_pretrained_asr) + +## Simultaneous Speech Translation Training + +### Wait-K with fixed pre-decision module +Fixed pre-decision indicates that the model operate simultaneous policy on the boundaries of fixed chunks. +Here is a example of fixed pre-decision ratio 7 (the simultaneous decision is made every 7 encoder states) and +a wait-3 policy model. Assuming the save directory is `${ST_SAVE_DIR}` +```bash + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --criterion label_smoothed_cross_entropy \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \ + --task speech_to_text \ + --arch convtransformer_simul_trans_espnet \ + --simul-type waitk_fixed_pre_decision \ + --waitk-lagging 3 \ + --fixed-pre-decision-ratio 7 + +``` +### Monotonic multihead attention with fixed pre-decision module +``` + fairseq-train ${MUSTC_ROOT}/en-de \ + --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \ + --save-dir ${ST_SAVE_DIR} --num-workers 8 \ + --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \ + --warmup-updates 4000 --max-update 100000 --max-tokens 40000 --seed 2 \ + --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --task speech_to_text \ + --criterion latency_augmented_label_smoothed_cross_entropy \ + --latency-weight-avg 0.1 \ + --arch convtransformer_simul_trans_espnet \ + --simul-type infinite_lookback_fixed_pre_decision \ + --fixed-pre-decision-ratio 7 +``` +## Inference & Evaluation +[SimulEval](https://github.com/facebookresearch/SimulEval) is used for evaluation. +The source file is a list of paths of audio files, +while target file is the corresponding translations. +``` +git clone https://github.com/facebookresearch/SimulEval.git +cd SimulEval +pip install -e . + +simuleval \ + --agent examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py + --src-file ${SRC_LIST_OF_AUDIO} + --tgt-file ${TGT_FILE} + --data-bin ${MUSTC_ROOT}/en-de \ + --model-path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} \ + --tgt-splitter-type SentencePieceModel \ + --tgt-splitter-path ${MUSTC_ROOT}/en-de/spm.model \ + --scores +``` + +A pretrained checkpoint can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/convtransformer_wait5_pre7), which is a wait-5 model with a pre-decision of 280 ms. The databin (containing dictionary, gcmvn file and sentencepiece model) can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/must_c_v1.0_en_de_databin.tgz). + +The output should be similar as follow: +```bash +{ + "Quality": { + "BLEU": 12.79214535384013 + }, + "Latency": { + "AL": 1669.5778120018108, + "AL_CA": 2077.9027656104813, + "AP": 0.7652936521983029, + "AP_CA": 0.8891561507382866, + "DAL": 2028.1566141735727, + "DAL_CA": 2497.336430059716 + } +} +``` + +The quality is measured by detokenized BLEU. So make sure that the predicted words sent to the server are detokenized. + +The latency metrics are +* Average Proportion +* Average Lagging +* Differentiable Average Lagging +Again they will also be evaluated on detokenized text. diff --git a/examples/speech_to_text/prep_librispeech_data.py b/examples/speech_to_text/prep_librispeech_data.py index 6a6f55ded4..7b08447190 100644 --- a/examples/speech_to_text/prep_librispeech_data.py +++ b/examples/speech_to_text/prep_librispeech_data.py @@ -71,7 +71,7 @@ def process(args): manifest["audio"].append(zip_manifest[sample_id]) duration_ms = int(wav.size(1) / sample_rate * 1000) manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) - manifest["tgt_text"].append(utt) + manifest["tgt_text"].append(utt.lower()) manifest["speaker"].append(spk_id) save_df_to_tsv( pd.DataFrame.from_dict(manifest), out_root / f"{split}.tsv" diff --git a/examples/speech_to_text/prep_mtedx_data.py b/examples/speech_to_text/prep_mtedx_data.py new file mode 100644 index 0000000000..34b1c398c8 --- /dev/null +++ b/examples/speech_to_text/prep_mtedx_data.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +import shutil +from itertools import groupby +from tempfile import NamedTemporaryFile +from typing import Tuple + +import pandas as pd +import soundfile as sf +from examples.speech_to_text.data_utils import ( + create_zip, + extract_fbank_features, + filter_manifest_df, + gen_config_yaml, + gen_vocab, + get_zip_manifest, + load_df_from_tsv, + save_df_to_tsv, +) +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + +from fairseq.data.audio.audio_utils import get_waveform + + +log = logging.getLogger(__name__) + + +MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"] + + +class mTEDx(Dataset): + """ + Create a Dataset for Multilingual TEDx. + Each item is a tuple of the form: waveform, sample_rate, source utterance, + target utterance, speaker_id, utterance_id + """ + + SPLITS = ["train", "valid", "test"] + LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar", "de-de", + "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es", "fr-pt", + "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"] + + def __init__(self, root: str, lang: str, split: str) -> None: + assert split in self.SPLITS and lang in self.LANGPAIRS + _root = Path(root) / f"{lang}" / "data" / split + wav_root, txt_root = _root / "wav", _root / "txt" + assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir() + # Load audio segments + try: + import yaml + except ImportError: + print("Please install PyYAML to load the Multilingual TEDx YAML files") + with open(txt_root / f"{split}.yaml") as f: + segments = yaml.load(f, Loader=yaml.BaseLoader) + # Load source and target utterances + src, tgt = lang.split("-") + for _lang in [src, tgt]: + with open(txt_root / f"{split}.{_lang}") as f: + utterances = [r.strip() for r in f] + assert len(segments) == len(utterances) + for i, u in enumerate(utterances): + segments[i][_lang] = u + # Gather info + self.data = [] + for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): + wav_filename = wav_filename.replace(".wav", ".flac") + wav_path = wav_root / wav_filename + sample_rate = sf.info(wav_path.as_posix()).samplerate + seg_group = sorted(_seg_group, key=lambda x: float(x["offset"])) + for i, segment in enumerate(seg_group): + offset = int(float(segment["offset"]) * sample_rate) + n_frames = int(float(segment["duration"]) * sample_rate) + _id = f"{wav_path.stem}_{i}" + self.data.append( + ( + wav_path.as_posix(), + offset, + n_frames, + sample_rate, + segment[src], + segment[tgt], + segment["speaker_id"], + tgt, + _id, + ) + ) + + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str, str]: + wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id = self.data[n] + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) + return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id + + def __len__(self) -> int: + return len(self.data) + + +def process(args): + root = Path(args.data_root).absolute() + for lang in mTEDx.LANGPAIRS: + cur_root = root / f"{lang}" + if not cur_root.is_dir(): + print(f"{cur_root.as_posix()} does not exist. Skipped.") + continue + # Extract features + feature_root = cur_root / "fbank80" + feature_root.mkdir(exist_ok=True) + for split in mTEDx.SPLITS: + print(f"Fetching split {split}...") + dataset = mTEDx(root.as_posix(), lang, split) + print("Extracting log mel filter bank features...") + for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset): + extract_fbank_features( + waveform, sample_rate, feature_root / f"{utt_id}.npy" + ) + # Pack features into ZIP + zip_path = cur_root / "fbank80.zip" + print("ZIPing features...") + create_zip(feature_root, zip_path) + print("Fetching ZIP manifest...") + zip_manifest = get_zip_manifest(zip_path) + # Generate TSV manifest + print("Generating manifest...") + train_text = [] + for split in mTEDx.SPLITS: + is_train_split = split.startswith("train") + manifest = {c: [] for c in MANIFEST_COLUMNS} + dataset = mTEDx(args.data_root, lang, split) + for wav, sr, src_utt, tgt_utt, speaker_id, tgt_lang, utt_id in tqdm(dataset): + manifest["id"].append(utt_id) + manifest["audio"].append(zip_manifest[utt_id]) + duration_ms = int(wav.size(1) / sr * 1000) + manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10)) + manifest["tgt_text"].append(src_utt if args.task == "asr" else tgt_utt) + manifest["speaker"].append(speaker_id) + manifest["tgt_lang"].append(tgt_lang) + if is_train_split: + train_text.extend(manifest["tgt_text"]) + df = pd.DataFrame.from_dict(manifest) + df = filter_manifest_df(df, is_train_split=is_train_split) + save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv") + # Generate vocab + v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for t in train_text: + f.write(t + "\n") + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="lb", + ) + # Clean up + shutil.rmtree(feature_root) + + +def process_joint(args): + cur_root = Path(args.data_root) + assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \ + "do not have downloaded data available for all languages" + # Generate vocab + vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size) + spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}" + with NamedTemporaryFile(mode="w") as f: + for lang in mTEDx.LANGPAIRS: + tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv" + df = load_df_from_tsv(tsv_path) + for t in df["tgt_text"]: + f.write(t + "\n") + special_symbols = None + if args.joint: + # Add tgt_lang tags to dict + special_symbols = list({f'' for lang in mTEDx.LANGPAIRS}) + gen_vocab( + Path(f.name), + cur_root / spm_filename_prefix, + args.vocab_type, + args.vocab_size, + special_symbols=special_symbols + ) + # Generate config YAML + gen_config_yaml( + cur_root, + spm_filename_prefix + ".model", + yaml_filename=f"config_{args.task}.yaml", + specaugment_policy="ld", + prepend_tgt_lang_tag=(args.joint), + ) + # Make symbolic links to manifests + for lang in mTEDx.LANGPAIRS: + for split in mTEDx.SPLITS: + src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv" + desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv" + if not desc_path.is_symlink(): + os.symlink(src_path, desc_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-root", "-d", required=True, type=str) + parser.add_argument( + "--vocab-type", + default="unigram", + required=True, + type=str, + choices=["bpe", "unigram", "char"], + ), + parser.add_argument("--vocab-size", default=8000, type=int) + parser.add_argument("--task", type=str, choices=["asr", "st"]) + parser.add_argument("--joint", action="store_true", help="") + args = parser.parse_args() + + if args.joint: + process_joint(args) + else: + process(args) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_to_text/prep_mustc_data.py b/examples/speech_to_text/prep_mustc_data.py index 520968401c..0ee204e651 100644 --- a/examples/speech_to_text/prep_mustc_data.py +++ b/examples/speech_to_text/prep_mustc_data.py @@ -13,8 +13,9 @@ from tempfile import NamedTemporaryFile from typing import Tuple +import numpy as np import pandas as pd -import torchaudio +import soundfile as sf from examples.speech_to_text.data_utils import ( create_zip, extract_fbank_features, @@ -24,11 +25,14 @@ get_zip_manifest, load_df_from_tsv, save_df_to_tsv, + cal_gcmvn_stats, ) -from torch import Tensor +import torch from torch.utils.data import Dataset from tqdm import tqdm +from fairseq.data.audio.audio_utils import get_waveform + log = logging.getLogger(__name__) @@ -69,7 +73,7 @@ def __init__(self, root: str, lang: str, split: str) -> None: self.data = [] for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]): wav_path = wav_root / wav_filename - sample_rate = torchaudio.info(wav_path.as_posix())[0].rate + sample_rate = sf.info(wav_path.as_posix()).samplerate seg_group = sorted(_seg_group, key=lambda x: x["offset"]) for i, segment in enumerate(seg_group): offset = int(float(segment["offset"]) * sample_rate) @@ -88,9 +92,10 @@ def __init__(self, root: str, lang: str, split: str) -> None: ) ) - def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]: + def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str, str, str, str]: wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n] - waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames) + waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset) + waveform = torch.from_numpy(waveform) return waveform, sr, src_utt, tgt_utt, spk_id, utt_id def __len__(self) -> int: @@ -111,10 +116,28 @@ def process(args): print(f"Fetching split {split}...") dataset = MUSTC(root.as_posix(), lang, split) print("Extracting log mel filter bank features...") + if split == 'train' and args.cmvn_type == "global": + print("And estimating cepstral mean and variance stats...") + gcmvn_feature_list = [] + for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset): - extract_fbank_features( - waveform, sample_rate, feature_root / f"{utt_id}.npy" + features = extract_fbank_features(waveform, sample_rate) + + np.save( + (feature_root / f"{utt_id}.npy").as_posix(), + features ) + + if split == 'train' and args.cmvn_type == "global": + if len(gcmvn_feature_list) < args.gcmvn_max_num: + gcmvn_feature_list.append(features) + + if split == 'train' and args.cmvn_type == "global": + # Estimate and save cmv + stats = cal_gcmvn_stats(gcmvn_feature_list) + with open(cur_root / "gcmvn.npz", "wb") as f: + np.savez(f, mean=stats["mean"], std=stats["std"]) + # Pack features into ZIP zip_path = cur_root / "fbank80.zip" print("ZIPing features...") @@ -158,6 +181,11 @@ def process(args): spm_filename_prefix + ".model", yaml_filename=f"config_{args.task}.yaml", specaugment_policy="lb", + cmvn_type=args.cmvn_type, + gcmvn_path=( + cur_root / "gcmvn.npz" if args.cmvn_type == "global" + else None + ), ) # Clean up shutil.rmtree(feature_root) @@ -216,6 +244,14 @@ def main(): parser.add_argument("--vocab-size", default=8000, type=int) parser.add_argument("--task", type=str, choices=["asr", "st"]) parser.add_argument("--joint", action="store_true", help="") + parser.add_argument("--cmvn-type", default="utterance", + choices=["global", "utterance"], + help="The type of cepstral mean and variance normalization") + parser.add_argument("--gcmvn-max-num", default=150000, type=int, + help=( + "Maximum number of sentences to use to estimate" + "global mean and variance" + )) args = parser.parse_args() if args.joint: diff --git a/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py new file mode 100644 index 0000000000..8b8003e1d5 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py @@ -0,0 +1,364 @@ +import math +import os +import json +import numpy as np +import torch +import torchaudio.compliance.kaldi as kaldi +import yaml +from fairseq import checkpoint_utils, tasks +from fairseq.file_io import PathManager + +try: + from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS + from simuleval.agents import SpeechAgent + from simuleval.states import ListEntry, SpeechStates +except ImportError: + print("Please install simuleval 'pip install simuleval'") + +SHIFT_SIZE = 10 +WINDOW_SIZE = 25 +SAMPLE_RATE = 16000 +FEATURE_DIM = 80 +BOW_PREFIX = "\u2581" + + +class OnlineFeatureExtractor: + """ + Extract speech feature on the fly. + """ + + def __init__(self, args): + self.shift_size = args.shift_size + self.window_size = args.window_size + assert self.window_size >= self.shift_size + + self.sample_rate = args.sample_rate + self.feature_dim = args.feature_dim + self.num_samples_per_shift = int(self.shift_size * self.sample_rate / 1000) + self.num_samples_per_window = int(self.window_size * self.sample_rate / 1000) + self.len_ms_to_samples = lambda x: x * self.sample_rate / 1000 + self.previous_residual_samples = [] + self.global_cmvn = args.global_cmvn + + def clear_cache(self): + self.previous_residual_samples = [] + + def __call__(self, new_samples): + samples = self.previous_residual_samples + new_samples + if len(samples) < self.num_samples_per_window: + self.previous_residual_samples = samples + return + + # num_frames is the number of frames from the new segment + num_frames = math.floor( + (len(samples) - self.len_ms_to_samples(self.window_size - self.shift_size)) + / self.num_samples_per_shift + ) + + # the number of frames used for feature extraction + # including some part of thte previous segment + effective_num_samples = int( + num_frames * self.len_ms_to_samples(self.shift_size) + + self.len_ms_to_samples(self.window_size - self.shift_size) + ) + + input_samples = samples[:effective_num_samples] + self.previous_residual_samples = samples[ + num_frames * self.num_samples_per_shift: + ] + + torch.manual_seed(1) + output = kaldi.fbank( + torch.FloatTensor(input_samples).unsqueeze(0), + num_mel_bins=self.feature_dim, + frame_length=self.window_size, + frame_shift=self.shift_size, + ).numpy() + + output = self.transform(output) + + return torch.from_numpy(output) + + def transform(self, input): + if self.global_cmvn is None: + return input + + mean = self.global_cmvn["mean"] + std = self.global_cmvn["std"] + + x = np.subtract(input, mean) + x = np.divide(x, std) + return x + + +class TensorListEntry(ListEntry): + """ + Data structure to store a list of tensor. + """ + + def append(self, value): + + if len(self.value) == 0: + self.value = value + return + + self.value = torch.cat([self.value] + [value], dim=0) + + def info(self): + return { + "type": str(self.new_value_type), + "length": self.__len__(), + "value": "" if type(self.value) is list else self.value.size(), + } + + +class FairseqSimulSTAgent(SpeechAgent): + + speech_segment_size = 40 # in ms, 4 pooling ratio * 10 ms step size + + def __init__(self, args): + super().__init__(args) + + self.eos = DEFAULT_EOS + + self.gpu = getattr(args, "gpu", False) + + self.args = args + + self.load_model_vocab(args) + + if getattr( + self.model.decoder.layers[0].encoder_attn, + 'pre_decision_ratio', + None + ) is not None: + self.speech_segment_size *= ( + self.model.decoder.layers[0].encoder_attn.pre_decision_ratio + ) + + args.global_cmvn = None + if args.config: + with open(args.config, "r") as f: + config = yaml.load(f, Loader=yaml.BaseLoader) + + if "global_cmvn" in config: + args.global_cmvn = np.load(config["global_cmvn"]["stats_npz_path"]) + + if args.global_stats: + with PathManager.open(args.global_stats, "r") as f: + global_cmvn = json.loads(f.read()) + self.global_cmvn = {"mean": global_cmvn["mean"], "std": global_cmvn["stddev"]} + + self.feature_extractor = OnlineFeatureExtractor(args) + + self.max_len = args.max_len + + self.force_finish = args.force_finish + + torch.set_grad_enabled(False) + + def build_states(self, args, client, sentence_id): + # Initialize states here, for example add customized entry to states + # This function will be called at beginning of every new sentence + states = SpeechStates(args, client, sentence_id, self) + self.initialize_states(states) + return states + + def to_device(self, tensor): + if self.gpu: + return tensor.cuda() + else: + return tensor.cpu() + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--model-path', type=str, required=True, + help='path to your pretrained model.') + parser.add_argument("--data-bin", type=str, required=True, + help="Path of data binary") + parser.add_argument("--config", type=str, default=None, + help="Path to config yaml file") + parser.add_argument("--global-stats", type=str, default=None, + help="Path to json file containing cmvn stats") + parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece", + help="Subword splitter type for target text") + parser.add_argument("--tgt-splitter-path", type=str, default=None, + help="Subword splitter model path for target text") + parser.add_argument("--user-dir", type=str, default="examples/simultaneous_translation", + help="User directory for simultaneous translation") + parser.add_argument("--max-len", type=int, default=200, + help="Max length of translation") + parser.add_argument("--force-finish", default=False, action="store_true", + help="Force the model to finish the hypothsis if the source is not finished") + parser.add_argument("--shift-size", type=int, default=SHIFT_SIZE, + help="Shift size of feature extraction window.") + parser.add_argument("--window-size", type=int, default=WINDOW_SIZE, + help="Window size of feature extraction window.") + parser.add_argument("--sample-rate", type=int, default=SAMPLE_RATE, + help="Sample rate") + parser.add_argument("--feature-dim", type=int, default=FEATURE_DIM, + help="Acoustic feature dimension.") + + # fmt: on + return parser + + def set_up_task(self, task_args): + return tasks.setup_task(task_args) + + def load_model_vocab(self, args): + + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + + task_args = state["cfg"]["task"] + task_args.data = args.data_bin + + if args.config is not None: + task_args.config_yaml = args.config + + task = self.set_up_task(task_args) + + # build model for ensemble + self.model = task.build_model(state["cfg"]["model"]) + self.model.load_state_dict(state["model"], strict=True) + self.model.eval() + self.model.share_memory() + + if self.gpu: + self.model.cuda() + + # Set dictionary + self.dict = {} + self.dict["tgt"] = task.target_dictionary + + def initialize_states(self, states): + self.feature_extractor.clear_cache() + states.units.source = TensorListEntry() + states.units.target = ListEntry() + states.incremental_states = dict() + + def segment_to_units(self, segment, states): + # Convert speech samples to features + features = self.feature_extractor(segment) + if features is not None: + return [features] + else: + return [] + + def units_to_segment(self, units, states): + # Merge sub word to full word. + if self.model.decoder.dictionary.eos() == units[0]: + return DEFAULT_EOS + + segment = [] + if None in units.value: + units.value.remove(None) + + for index in units: + if index is None: + units.pop() + token = self.model.decoder.dictionary.string([index]) + if token.startswith(BOW_PREFIX): + if len(segment) == 0: + segment += [token.replace(BOW_PREFIX, "")] + else: + for j in range(len(segment)): + units.pop() + + string_to_return = ["".join(segment)] + + if self.model.decoder.dictionary.eos() == units[0]: + string_to_return += [DEFAULT_EOS] + + return string_to_return + else: + segment += [token.replace(BOW_PREFIX, "")] + + if ( + len(units) > 0 + and self.model.decoder.dictionary.eos() == units[-1] + or len(states.units.target) > self.max_len + ): + tokens = [self.model.decoder.dictionary.string([unit]) for unit in units] + return ["".join(tokens).replace(BOW_PREFIX, "")] + [DEFAULT_EOS] + + return None + + def update_model_encoder(self, states): + if len(states.units.source) == 0: + return + src_indices = self.to_device( + states.units.source.value.unsqueeze(0) + ) + src_lengths = self.to_device( + torch.LongTensor([states.units.source.value.size(0)]) + ) + + states.encoder_states = self.model.encoder(src_indices, src_lengths) + torch.cuda.empty_cache() + + def update_states_read(self, states): + # Happens after a read action. + self.update_model_encoder(states) + + def policy(self, states): + if not getattr(states, "encoder_states", None): + return READ_ACTION + + tgt_indices = self.to_device( + torch.LongTensor( + [self.model.decoder.dictionary.eos()] + + [x for x in states.units.target.value if x is not None] + ).unsqueeze(0) + ) + + states.incremental_states["steps"] = { + "src": states.encoder_states["encoder_out"][0].size(0), + "tgt": 1 + len(states.units.target), + } + + states.incremental_states["online"] = not states.finish_read() + + x, outputs = self.model.decoder.forward( + prev_output_tokens=tgt_indices, + encoder_out=states.encoder_states, + incremental_state=states.incremental_states, + ) + + states.decoder_out = x + + states.decoder_out_extra = outputs + + torch.cuda.empty_cache() + + if outputs["action"] == 0: + return READ_ACTION + else: + return WRITE_ACTION + + def predict(self, states): + decoder_states = states.decoder_out + + lprobs = self.model.get_normalized_probs( + [decoder_states[:, -1:]], log_probs=True + ) + + index = lprobs.argmax(dim=-1) + + index = index[0, 0].item() + + if ( + self.force_finish + and index == self.model.decoder.dictionary.eos() + and not states.finish_read() + ): + # If we want to force finish the translation + # (don't stop before finish reading), return a None + # self.model.decoder.clear_cache(states.incremental_states) + index = None + + return index diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py new file mode 100644 index 0000000000..45df5fa227 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +from fairseq import checkpoint_utils, utils, tasks + +from . import DEFAULT_EOS, GET, SEND +from .agent import Agent + + +class SimulTransAgent(Agent): + def __init__(self, args): + # Load Model + self.load_model(args) + + # build word spliter + self.build_word_splitter(args) + + self.max_len = args.max_len + + self.eos = DEFAULT_EOS + + @staticmethod + def add_args(parser): + parser.add_argument( + "--model-path", + type=str, + required=True, + help="path to your pretrained model.", + ) + parser.add_argument( + "--data-bin", type=str, required=True, help="Path of data binary" + ) + parser.add_argument( + "--user-dir", + type=str, + default="example/simultaneous_translation", + help="User directory for simultaneous translation", + ) + parser.add_argument( + "--src-splitter-type", + type=str, + default=None, + help="Subword splitter type for source text", + ) + parser.add_argument( + "--tgt-splitter-type", + type=str, + default=None, + help="Subword splitter type for target text", + ) + parser.add_argument( + "--src-splitter-path", + type=str, + default=None, + help="Subword splitter model path for source text", + ) + parser.add_argument( + "--tgt-splitter-path", + type=str, + default=None, + help="Subword splitter model path for target text", + ) + parser.add_argument( + "--max-len", + type=int, + default=150, + help="Maximum length difference between source and target prediction", + ) + parser.add_argument( + "--model-overrides", + default="{}", + type=str, + metavar="DICT", + help="A dictionary used to override model args at generation " + "that were used during model training", + ) + # fmt: on + return parser + + def load_dictionary(self, task): + raise NotImplementedError + + def load_model(self, args): + args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..") + utils.import_user_module(args) + filename = args.model_path + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, json.loads(args.model_overrides) + ) + + saved_args = state["args"] + saved_args.data = args.data_bin + + task = tasks.setup_task(saved_args) + + # build model for ensemble + self.model = task.build_model(saved_args) + self.model.load_state_dict(state["model"], strict=True) + + # Set dictionary + self.load_dictionary(task) + + def init_states(self): + return { + "indices": {"src": [], "tgt": []}, + "tokens": {"src": [], "tgt": []}, + "segments": {"src": [], "tgt": []}, + "steps": {"src": 0, "tgt": 0}, + "finished": False, + "finish_read": False, + "model_states": {}, + } + + def update_states(self, states, new_state): + raise NotImplementedError + + def policy(self, states): + # Read and Write policy + action = None + + while action is None: + if states["finished"]: + # Finish the hypo by sending eos to server + return self.finish_action() + + # Model make decision given current states + decision = self.model.decision_from_states(states) + + if decision == 0 and not self.finish_read(states): + # READ + action = self.read_action(states) + else: + # WRITE + action = self.write_action(states) + + # None means we make decision again but not sending server anything + # This happened when read a buffered token + # Or predict a subword + return action + + def finish_read(self, states): + raise NotImplementedError + + def write_action(self, states): + token, index = self.model.predict_from_states(states) + + if ( + index == self.dict["tgt"].eos() + or len(states["tokens"]["tgt"]) > self.max_len + ): + # Finish this sentence is predict EOS + states["finished"] = True + end_idx_last_full_word = self._target_length(states) + + else: + states["tokens"]["tgt"] += [token] + end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( + states["tokens"]["tgt"] + ) + self._append_indices(states, [index], "tgt") + + if end_idx_last_full_word > states["steps"]["tgt"]: + # Only sent detokenized full words to the server + word = self.word_splitter["tgt"].merge( + states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] + ) + states["steps"]["tgt"] = end_idx_last_full_word + states["segments"]["tgt"] += [word] + + return {"key": SEND, "value": word} + else: + return None + + def read_action(self, states): + return {"key": GET, "value": None} + + def finish_action(self): + return {"key": SEND, "value": DEFAULT_EOS} + + def reset(self): + pass + + def finish_eval(self, states, new_state): + if len(new_state) == 0 and len(states["indices"]["src"]) == 0: + return True + return False + + def _append_indices(self, states, new_indices, key): + states["indices"][key] += new_indices + + def _target_length(self, states): + return len(states["tokens"]["tgt"]) diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index aeb96a14b1..e86a0d13b8 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -55,7 +55,7 @@ Translate ``` CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out -grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out +grep ^H $TMP/fairseq.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/mt.out ``` Post-process @@ -88,7 +88,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_ --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out -grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores +grep ^H $TMP/dropout.scoring.out | cut -d- -f2- | sort -n | cut -f2 > $TMP/dropout.scores ``` @@ -112,7 +112,7 @@ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_ --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out -grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_ +grep ^H $TMP/dropout.generation.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/dropout.hypotheses_ sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl -l $TGT_LANG > $TMP/dropout.hypotheses diff --git a/examples/unsupervised_quality_estimation/meteor.py b/examples/unsupervised_quality_estimation/meteor.py index 4a214e794d..2ee0448cf1 100644 --- a/examples/unsupervised_quality_estimation/meteor.py +++ b/examples/unsupervised_quality_estimation/meteor.py @@ -85,19 +85,19 @@ def read_output(meteor_output_path, n_repeats): def main(): parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input") + parser.add_argument("-i", "--infile") parser.add_argument("-n", "--repeat_times", type=int) parser.add_argument("-m", "--meteor") parser.add_argument("-o", "--output") args = parser.parse_args() - translations = read_translations(args.infile, args.repetitions) + translations = read_translations(args.infile, args.repeat_times) sys.stderr.write("\nGenerating input for Meteor...") - ref_path, mt_path = generate_input(translations, args.repetitions) + ref_path, mt_path = generate_input(translations, args.repeat_times) sys.stderr.write("\nRunning Meteor...") out_path = run_meteor(ref_path, mt_path, args.meteor) sys.stderr.write("\nReading output...") - scores = read_output(out_path, args.repetitions) + scores = read_output(out_path, args.repeat_times) sys.stderr.write("\nWriting results...") with open(args.output, "w") as o: for scr in scores: diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 663adf97dc..bfed3913cf 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -149,7 +149,7 @@ To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the t ## Use wav2vec 2.0 with 🤗Transformers: -Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since vesion 4.3. +Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.3. Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). @@ -222,6 +222,58 @@ $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 - --max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test ``` +### Run wav2vec2 pre-training on Google Cloud TPUs: + +Wav2Vec2 is now supported on TPUs! It's currently pre-training only. + +#### Using hydra on a v3-8: + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu.yaml +``` + +#### Using command line arguments on a v3-8: + +``` +$ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size 8 --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + +#### Using hydra on a pod slice (v3-N with N > 8): + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu-pod.yaml # edit distributed-world-size accordingly +``` + +#### Using command line arguments on a pod slice (v3-N with N > 8): + + +``` +$ python -m torch_xla.distributed.xla_dist \ + --tpu ${TPUNAME} --conda-env=torch-xla-${TORCH_XLA_VERSION} --env OMP_NUM_THREADS=1 \ + -- \ +python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size ${WORLD_SIZE} --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + ### Extract embeddings from the downstream task data: ``` diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml new file mode 100644 index 0000000000..676c9fe339 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml @@ -0,0 +1,71 @@ +# @package _group_ + +common: + tpu: true + fp16: false + log_format: json + log_interval: 10 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 250000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + enable_padding: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 256 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + mask_channel_prob: 0.1 + mask_prob: 0.1 + + feature_grad_mult: 1.0 + diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml new file mode 100644 index 0000000000..c45c4d9117 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml @@ -0,0 +1,71 @@ +# @package _group_ + +common: + tpu: true + fp16: false + log_format: json + log_interval: 10 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 250000 + min_sample_size: 32000 + normalize: true + num_batch_buckets: 3 + precompute_mask_indices: true + enable_padding: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 8 + ddp_backend: legacy_ddp + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 256 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + mask_channel_prob: 0.1 + mask_prob: 0.1 + + feature_grad_mult: 1.0 + diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 2f209b6b39..5a98dad2aa 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -21,7 +21,7 @@ ) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder -from omegaconf import DictConfig, open_dict +from omegaconf import Container, DictConfig, open_dict, OmegaConf logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir - if cfg.distributed_rank == 0: + if trainer.data_parallel_rank == 0: os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) @@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): trainer.consolidate_optimizer() - if not trainer.is_data_parallel_master: + if not trainer.should_save_checkpoint_on_current_rank: return write_timer = meters.StopwatchMeter() @@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): def is_better(a, b): return a >= b if cfg.maximize_best_checkpoint_metric else a <= b - suffix = cfg.checkpoint_suffix or "" + suffix = trainer.checkpoint_suffix checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 @@ -93,9 +93,17 @@ def is_better(a, b): if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: - assert PathManager.copy( - checkpoints[0], cp, overwrite=True - ), f"Failed to copy {checkpoints[0]} to {cp}" + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" write_timer.stop() logger.info( @@ -157,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = cfg.checkpoint_suffix + suffix = trainer.checkpoint_suffix if ( cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' @@ -182,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): raise ValueError( f"--funetune-from-model {cfg.finetune_from_model} does not exist" ) - elif cfg.model_parallel_size > 1: + elif suffix is not None: checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = cfg.restore_file @@ -267,8 +275,22 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) - if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: - overwrite_args_by_name(state["cfg"], arg_overrides) + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) state = _upgrade_state_dict(state) return state @@ -349,6 +371,9 @@ def load_model_ensemble_and_task( if task is None: task = tasks.setup_task(cfg.task) + if "task_state" in state: + task.load_state_dict(state["task_state"]) + # build model for ensemble model = task.build_model(cfg.model) @@ -380,7 +405,23 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(obj, f): +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): if isinstance(f, str): with PathManager.open(f, "wb") as h: torch_persistent_save(obj, h) @@ -393,67 +434,6 @@ def torch_persistent_save(obj, f): logger.error(traceback.format_exc()) -def save_state( - filename, - cfg: FairseqConfig, - model_state_dict, - criterion, - optimizer, - lr_scheduler, - num_updates, - optim_history=None, - extra_state=None, - **kwargs, -): - from fairseq import utils - - if optim_history is None: - optim_history = [] - if extra_state is None: - extra_state = {} - state_dict = { - "cfg": cfg, - "args": kwargs.get("args", None), - "model": model_state_dict or {}, - "optimizer_history": optim_history - + [ - { - "criterion_name": criterion.__class__.__name__, - "optimizer_name": optimizer.__class__.__name__, - "lr_scheduler_state": lr_scheduler.state_dict(), - "num_updates": num_updates, - } - ], - "extra_state": extra_state, - } - if utils.has_parameters(criterion): - state_dict["criterion"] = criterion.state_dict() - - if cfg is None: - cfg = state_dict["args"] - assert cfg is not None, "must provide cfg or args" - - if isinstance(cfg, DictConfig): - no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state - else: - no_save_optimizer_state = cfg.no_save_optimizer_state - if not no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - - # keep everything on CPU - state_dict = utils.move_to_cpu(state_dict) - - if PathManager.supports_rename(filename): - # do atomic save - with PathManager.open(filename + ".tmp", "wb") as f: - torch_persistent_save(state_dict, f) - PathManager.rename(filename + ".tmp", filename) - else: - # fallback to non-atomic save - with PathManager.open(filename, "wb") as f: - torch_persistent_save(state_dict, f) - - def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks @@ -494,7 +474,7 @@ def _upgrade_state_dict(state): if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( + if "args" in state and hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions @@ -547,15 +527,39 @@ def _upgrade_state_dict(state): if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): state["args"].lr = [state["args"].lr] # convert task data arg to a string instead of List[string] - if hasattr(state["args"], "data") and isinstance(state["args"].data, list) and len(state["args"].data) > 0: + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): state["args"].data = state["args"].data[0] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: - with open_dict(state["cfg"]): + cfg = state["cfg"] + with open_dict(cfg): # any upgrades for Hydra-based configs - pass + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = "hard" + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" return state diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py similarity index 86% rename from examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py rename to fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index 761cfe61a1..aa3dba31e2 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -53,8 +53,28 @@ def add_args(parser): LatencyAugmentedLabelSmoothedCrossEntropyCriterion, LatencyAugmentedLabelSmoothedCrossEntropyCriterion, ).add_args(parser) - """Add criterion-specific arguments to the parser.""" # fmt: off + + """Add criterion-specific arguments to the parser.""" + parser.add_argument( + "--label-smoothing", + default=0.0, + type=float, + metavar="D", + help="epsilon for label smoothing, 0 means no label smoothing", + ) + parser.add_argument( + "--ignore_prefix_size", + default=0, + type=int, + help="ignore first N tokens", + ) + parser.add_argument( + "--report-accuracy", + default=False, + type=bool, + help="report accuracy metric", + ) parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D', help="Average loss weight") parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D', diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 859177f2b6..f682508cb1 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -31,7 +31,7 @@ class Wav2VecCriterionConfig(FairseqDataclass): default_factory=lambda: [], metadata={"help": "output keys to log"}, ) - +from fairseq.utils import index_put, is_xla_tensor @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): @@ -52,7 +52,9 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) + self.xla = is_xla_tensor(logits) + # XXX: handle weights on xla. weights = None if hasattr(model, "get_target_weights") and not self.infonce: weights = model.get_target_weights(target, net_output) @@ -61,21 +63,31 @@ def forward(self, model, sample, reduce=True): losses = [] + reduction = "none" if ((not reduce) or self.xla) else "sum" if self.infonce: - loss = F.cross_entropy( - logits, - target, - reduction="sum" if reduce else "none", - ) + loss = F.cross_entropy(logits, target, reduction=reduction) else: loss = F.binary_cross_entropy_with_logits( - logits, - target.float(), - weights, - reduction="sum" if reduce else "none", + logits, target.float(), weights, reduction=reduction + ) + + if self.xla: + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we use mask indices to adjust loss. + mi = ( + sample['net_input']['mask_indices'] + .transpose(0, 1) # logits are transposed in `model.get_logits` + .reshape(logits.size(0)) ) + loss = (loss * mi).sum() if reduce else (loss * mi) - sample_size = target.numel() if self.infonce else target.long().sum().item() + if 'sample_size' in sample and self.infonce: + sample_size = sample['sample_size'] + elif 'mask_indices' in sample['net_input']: + sample_size = sample['net_input']['mask_indices'].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss.detach().clone()) if self.loss_weights is not None: @@ -95,7 +107,7 @@ def forward(self, model, sample, reduce=True): losses.append(p) logging_output = { - "loss": loss.item() if reduce else loss, + "loss": loss.item() if (reduce and not self.xla) else loss.detach(), "ntokens": sample_size, "nsentences": sample["id"].numel(), "sample_size": sample_size, @@ -111,11 +123,14 @@ def forward(self, model, sample, reduce=True): if not self.training: logging_output["target"] = target.cpu().numpy() elif lk in net_output: - logging_output[lk] = float(net_output[lk]) + value = net_output[lk] + if not is_xla_tensor(value): + value = float(value) + logging_output[lk] = value if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f"loss_{i}"] = l.item() + logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach() if self.infonce: with torch.no_grad(): @@ -126,9 +141,15 @@ def forward(self, model, sample, reduce=True): assert logits.dim() > 1, logits.shape max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 - both = max & min - corr = max.long().sum().item() - both.long().sum().item() - count = max.numel() + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) logging_output["correct"] = corr logging_output["count"] = count @@ -188,11 +209,15 @@ def reduce_metrics(logging_outputs) -> None: else: metrics.log_scalar(k, val / len(logging_outputs), round=3) - @staticmethod - def logging_outputs_can_be_summed() -> bool: + # FIXME: revert when gather based xla reduction is implemented + #@staticmethod + #def logging_outputs_can_be_summed() -> bool: + def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ - return False + # XXX: Gather based reduction not implemented for xla yet. + # So we fall to sum based reduction for xla. + return self.xla diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index de08669851..ddd5642c7e 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -1,35 +1,80 @@ -import os.path as op +from pathlib import Path from typing import BinaryIO, Optional, Tuple, Union import numpy as np +import torch + + +SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} + + +def _convert_to_mono( + waveform: torch.FloatTensor, sample_rate: int +) -> torch.FloatTensor: + if waveform.shape[0] > 1: + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError( + "Please install torchaudio to convert multi-channel audios" + ) + effects = [['channels', '1']] + return ta_sox.apply_effects_tensor(waveform, sample_rate, effects)[0] + return waveform + + +def convert_to_mono(waveform: np.ndarray, sample_rate: int) -> np.ndarray: + if waveform.shape[0] > 1: + _waveform = torch.from_numpy(waveform) + return _convert_to_mono(_waveform, sample_rate).numpy() + return waveform def get_waveform( - path_or_fp: Union[str, BinaryIO], normalization=True + path_or_fp: Union[str, BinaryIO], normalization=True, mono=True, + frames=-1, start=0, always_2d=True ) -> Tuple[np.ndarray, int]: - """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. + """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. Args: path_or_fp (str or BinaryIO): the path or file-like object normalization (bool): Normalize values to [-1, 1] (Default: True) + mono (bool): convert multi-channel audio to mono-channel one + frames (int): the number of frames to read. (-1 for reading all) + start (int): Where to start reading. A negative value counts from the end. + always_2d (bool): always return 2D array even for mono-channel audios + Returns: + waveform (numpy.ndarray): 1D or 2D waveform (channels x length) + sample_rate (float): sample rate """ if isinstance(path_or_fp, str): - ext = op.splitext(op.basename(path_or_fp))[1] - if ext not in {".flac", ".wav"}: + ext = Path(path_or_fp).suffix + if ext not in SF_AUDIO_FILE_EXTENSIONS: raise ValueError(f"Unsupported audio format: {ext}") try: import soundfile as sf except ImportError: - raise ImportError("Please install soundfile to load WAV/FLAC file") + raise ImportError( + "Please install soundfile to load WAV/FLAC/OGG Vorbis audios" + ) - waveform, sample_rate = sf.read(path_or_fp, dtype="float32") + waveform, sample_rate = sf.read( + path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start + ) + waveform = waveform.T # T x C -> C x T + if mono and waveform.shape[0] > 1: + waveform = convert_to_mono(waveform, sample_rate) if not normalization: waveform *= 2 ** 15 # denormalized to 16-bit signed integers + if not always_2d: + waveform = waveform.squeeze(axis=0) return waveform, sample_rate -def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_kaldi_fbank( + waveform: np.ndarray, sample_rate: int, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via PyKaldi.""" try: from kaldi.feat.mel import MelBanksOptions @@ -45,19 +90,19 @@ def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: opts.mel_opts = mel_opts opts.frame_opts = frame_opts fbank = Fbank(opts=opts) - features = fbank.compute(Vector(waveform), 1.0).numpy() + features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() return features except ImportError: return None -def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: +def _get_torchaudio_fbank( + waveform: np.ndarray, sample_rate, n_bins=80 +) -> Optional[np.ndarray]: """Get mel-filter bank features via TorchAudio.""" try: - import torch import torchaudio.compliance.kaldi as ta_kaldi - - waveform = torch.from_numpy(waveform).unsqueeze(0) + waveform = torch.from_numpy(waveform) features = ta_kaldi.fbank( waveform, num_mel_bins=n_bins, sample_frequency=sample_rate ) @@ -71,11 +116,11 @@ def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: (faster CPP implementation) to TorchAudio (Python implementation). Note that Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the waveform should not be normalized.""" - sound, sample_rate = get_waveform(path_or_fp, normalization=False) + waveform, sample_rate = get_waveform(path_or_fp, normalization=False) - features = _get_kaldi_fbank(sound, sample_rate, n_bins) + features = _get_kaldi_fbank(waveform, sample_rate, n_bins) if features is None: - features = _get_torchaudio_fbank(sound, sample_rate, n_bins) + features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) if features is None: raise ImportError( "Please install pyKaldi or torchaudio to enable " diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index ac5acd03bb..d0ff604e2b 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -12,7 +12,8 @@ import torch import torch.nn.functional as F -from .. import FairseqDataset +from .. import FairseqDataset, BaseWrapperDataset +from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes logger = logging.getLogger(__name__) @@ -23,11 +24,12 @@ def __init__( self, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__() @@ -37,10 +39,17 @@ def __init__( max_sample_size if max_sample_size is not None else sys.maxsize ) self.min_sample_size = min_sample_size - self.min_length = min_length self.pad = pad self.shuffle = shuffle self.normalize = normalize + self.compute_mask_indices = compute_mask_indices + if self.compute_mask_indices: + self.mask_compute_kwargs = mask_compute_kwargs + self._features_size_map = {} + self._C = mask_compute_kwargs['encoder_embed_dim'] + self._conv_feature_layers = eval( + mask_compute_kwargs['conv_feature_layers'] + ) def __getitem__(self, index): raise NotImplementedError() @@ -72,6 +81,45 @@ def crop_to_max_size(self, wav, target_size): end = size - diff + start return wav[start:end] + def _compute_mask_indices(self, dims, padding_mask): + B, T, C = dims + mask_indices, mask_channel_indices = None, None + if self.mask_compute_kwargs['mask_prob'] > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_compute_kwargs['mask_prob'], + self.mask_compute_kwargs['mask_length'], + self.mask_compute_kwargs['mask_selection'], + self.mask_compute_kwargs['mask_other'], + min_masks=2, + no_overlap=self.mask_compute_kwargs['no_mask_overlap'], + min_space=self.mask_compute_kwargs['mask_min_space'], + ) + mask_indices = torch.from_numpy(mask_indices) + if self.mask_compute_kwargs['mask_channel_prob'] > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_compute_kwargs['mask_channel_prob'], + self.mask_compute_kwargs['mask_channel_length'], + self.mask_compute_kwargs['mask_channel_selection'], + self.mask_compute_kwargs['mask_channel_other'], + no_overlap=self.mask_compute_kwargs['no_mask_channel_overlap'], + min_space=self.mask_compute_kwargs['mask_channel_min_space'], + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .unsqueeze(1) + .expand(-1, T, -1) + ) + + return mask_indices, mask_channel_indices + + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) + def collater(self, samples): samples = [s for s in samples if s["source"] is not None] if len(samples) == 0: @@ -103,9 +151,55 @@ def collater(self, samples): collated_sources[i] = self.crop_to_max_size(source, target_size) input = {"source": collated_sources} + out = {"id": torch.LongTensor([s["id"] for s in samples])} if self.pad: input["padding_mask"] = padding_mask - return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} + + if hasattr(self, 'num_buckets') and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s['id']] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input['source'] = self._bucket_tensor( + collated_sources, num_pad, 0 + ) + input['padding_mask'] = self._bucket_tensor( + padding_mask, num_pad, True + ) + + if self.compute_mask_indices: + B = input['source'].size(0) + T = self._get_mask_indices_dims(input['source'].size(-1)) + padding_mask_reshaped = input['padding_mask'].clone() + extra = padding_mask_reshaped.size(1) % T + if extra > 0: + padding_mask_reshaped = padding_mask_reshaped[:, :-extra] + padding_mask_reshaped = padding_mask_reshaped.view( + padding_mask_reshaped.size(0), T, -1 + ) + padding_mask_reshaped = padding_mask_reshaped.all(-1) + input['padding_count'] = ( + padding_mask_reshaped.sum(-1).max().item() + ) + mask_indices, mask_channel_indices = self._compute_mask_indices( + (B, T, self._C), padding_mask_reshaped, + ) + input["mask_indices"] = mask_indices + input["mask_channel_indices"] = mask_channel_indices + out['sample_size'] = mask_indices.sum().item() + + out["net_input"] = input + return out + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self._features_size_map: + L_in = size + for (_, kernel_size, stride) in self._conv_feature_layers: + L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] def num_tokens(self, index): return self.size(index) @@ -136,20 +230,23 @@ def __init__( manifest_path, sample_rate, max_sample_size=None, - min_sample_size=None, + min_sample_size=0, shuffle=True, - min_length=0, pad=False, normalize=False, + num_buckets=0, + compute_mask_indices=False, + **mask_compute_kwargs, ): super().__init__( sample_rate=sample_rate, max_sample_size=max_sample_size, min_sample_size=min_sample_size, shuffle=shuffle, - min_length=min_length, pad=pad, normalize=normalize, + compute_mask_indices=compute_mask_indices, + **mask_compute_kwargs, ) self.fnames = [] @@ -162,14 +259,32 @@ def __init__( items = line.strip().split("\t") assert len(items) == 2, line sz = int(items[1]) - if min_length is not None and sz < min_length: + if min_sample_size is not None and sz < min_sample_size: skipped += 1 continue self.fnames.append(items[0]) self.line_inds.add(i) self.sizes.append(sz) + self.set_bucket_info(num_buckets) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 39d22c7a5e..c6c64db084 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -153,7 +153,8 @@ def get_features_or_waveform_from_uncompressed_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_flac_or_wav_data(data): - features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f) + features_or_waveform = \ + get_waveform(f, always_2d=False)[0] if need_waveform else get_fbank(f) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform @@ -178,7 +179,7 @@ def get_features_or_waveform(path: str, need_waveform=False): if len(extra) == 0: if need_waveform: - return get_waveform(_path) + return get_waveform(_path, always_2d=False) return get_features_from_npy_or_audio(_path) elif len(extra) == 2: extra = [int(i) for i in extra] diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index cda8834ac8..0f94100148 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -6,6 +6,7 @@ import numpy as np import torch.nn.functional as F from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes class BucketPadLengthDataset(BaseWrapperDataset): @@ -29,42 +30,43 @@ def __init__( num_buckets, pad_idx, left_pad, + tensor_key=None, ): super().__init__(dataset) self.pad_idx = pad_idx self.left_pad = left_pad assert num_buckets > 0 - self.buckets = np.unique( - np.percentile( - sizes, - np.linspace(0, 100, num_buckets + 1), - interpolation="lower", - )[1:] - ) + self.buckets = get_buckets(sizes, num_buckets) + self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + self._tensor_key = tensor_key - def get_bucketed_sizes(orig_sizes, buckets): - sizes = np.copy(orig_sizes) - assert np.min(sizes) >= 0 - start_val = -1 - for end_val in buckets: - mask = (sizes > start_val) & (sizes <= end_val) - sizes[mask] = end_val - start_val = end_val - return sizes + def _set_tensor(self, item, val): + if self._tensor_key is None: + return val + item[self._tensor_key] = val + return item - self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + def _get_tensor(self, item): + if self._tensor_key is None: + return item + return item[self._tensor_key] - def __getitem__(self, index): - item = self.dataset[index] - bucket_size = self._bucketed_sizes[index] - num_pad = bucket_size - item.size(-1) + def _pad(self, tensor, bucket_size, dim=-1): + num_pad = bucket_size - tensor.size(dim) return F.pad( - item, + tensor, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + tensor = self._get_tensor(item) + padded = self._pad(tensor, bucket_size) + return self._set_tensor(item, padded) + @property def sizes(self): return self._bucketed_sizes diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 1a83063542..01c743c3e8 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -17,6 +17,8 @@ import numpy as np import torch +from fairseq.file_io import PathManager + logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def infer_language_pair(path): """Infer language pair from filename: .-.(...).idx""" src, dst = None, None - for filename in os.listdir(path): + for filename in PathManager.ls(path): parts = filename.split(".") if len(parts) >= 3 and len(parts[1].split("-")) == 2: return parts[1].split("-") @@ -307,8 +309,16 @@ def batch_by_size( "Please build Cython components with: `pip install --editable .` " "or `python setup.py build_ext --inplace`" ) + except ValueError: + raise ValueError( + "Please build (or rebuild) Cython components with: `pip install " + " --editable .` or `python setup.py build_ext --inplace`." + ) - max_tokens = max_tokens if max_tokens is not None else -1 + # added int() to avoid TypeError: an integer is required + max_tokens = ( + int(max_tokens) if max_tokens is not None else -1 + ) max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple @@ -514,3 +524,25 @@ def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: return ~lengths_to_padding_mask(lens) + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation='lower', + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index a821417321..802e37a7ff 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -15,12 +15,23 @@ from . import FairseqDataset +from typing import Union -def __best_fitting_dtype(vocab_size=None): - if vocab_size is not None and vocab_size < 65500: + +def best_fitting_int_dtype( + max_int_to_represent, +) -> Union[np.uint16, np.uint32, np.int64]: + + if max_int_to_represent is None: + return np.uint32 # Safe guess + elif max_int_to_represent < 65500: return np.uint16 + elif max_int_to_represent < 4294967295: + return np.uint32 else: - return np.int32 + return np.int64 + # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly + # https://github.com/numpy/numpy/issues/5745 def get_available_dataset_impl(): @@ -48,7 +59,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) + out_file, dtype=best_fitting_int_dtype(vocab_size) ) elif impl == "fasta": raise NotImplementedError @@ -92,7 +103,7 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) -dtypes = { +_code_to_dtype = { 1: np.uint8, 2: np.int8, 3: np.int16, @@ -101,12 +112,14 @@ def write_longs(f, a): 6: np.float, 7: np.double, 8: np.uint16, + 9: np.uint32, + 10: np.uint64, } -def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: +def _dtype_header_code(dtype) -> int: + for k in _code_to_dtype.keys(): + if _code_to_dtype[k] == dtype: return k raise ValueError(dtype) @@ -141,7 +154,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack(" 0: + sampled_indices += list( + np.concatenate( + ( + np.repeat( + np.arange(self.dataset_offsets[i], high), num_copies + ), + dataset_indices, + ) + ) + ) + else: + sampled_indices += list(dataset_indices) + + assert ( + len(sampled_indices) == self.total_num_instances + ), f"{len(sampled_indices)} vs {self.total_num_instances}" + + np.random.shuffle(sampled_indices) if self.sort_indices: sampled_indices.sort(key=lambda i: self.num_tokens(i)) - return np.array(sampled_indices, dtype=np.int64) - - def _sample(self, indices, counters): - # First pick dataset - dataset_idx = np.random.choice(len(self.distribution), p=self.distribution) - - # Then get dataset internal index - idx = indices[dataset_idx][counters[dataset_idx]] - - # Convert to multi-datasets index - idx += self.dataset_offsets[dataset_idx] - counters[dataset_idx] += 1 - - # Reset if we reach end - if counters[dataset_idx] == len(self.dataset_list[dataset_idx]): - counters[dataset_idx] = 0 - indices[dataset_idx] = np.random.permutation( - len(self.dataset_list[dataset_idx]) + logger.info( + "multi_corpus_dataset ordered_indices took {}s".format( + time.time() - start + ) ) - - return idx + return np.array(sampled_indices, dtype=np.int64) def _map_index(self, index: int): """ @@ -125,18 +148,26 @@ def __len__(self): return self.total_num_instances def __getitem__(self, index): - index, key = self._map_index(index) - return self.datasets[key][index] + new_index, key = self._map_index(index) + try: + item = self.datasets[key][new_index] + item["full_id"] = index + return item + except Exception as e: + e.args = (f"Error from {key} dataset", *e.args) + raise def collater(self, samples): """ - Since we enforce all datsets to be the same, collating is just - picking the first one and doing collate. + If we are doing batch sampling, then pick the right collater to use. + + Otherwise we assume all collaters are the same. """ if len(samples) == 0: return None + _, key = self._map_index(samples[0]["full_id"]) - return list(self.datasets.values())[0].collater(samples) + return self.datasets[key].collater(samples) def num_tokens(self, index: int): index, key = self._map_index(index) @@ -164,3 +195,34 @@ def supports_fetch_outside_dataloader(self): self.datasets[key].supports_fetch_outside_dataloader for key in self.datasets ) + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not self.batch_sample: + return super().batch_by_size( + indices, max_tokens, max_sentences, required_batch_size_multiple + ) + + dataset_indices = {key: [] for key in self.datasets} + for i in indices: + _, key = self._map_index(i) + dataset_indices[key].append(i) + + batches = [] + for key in dataset_indices: + cur_batches = super().batch_by_size( + np.array(dataset_indices[key], dtype=np.int64), + max_tokens, + max_sentences, + required_batch_size_multiple, + ) + logger.info(f"Created {len(cur_batches)} batches for dataset {key}") + batches += cur_batches + + # Assume shuffling is handled in fairseq/data/iterators.py + return batches diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index f4bb6472d7..b9fab3b739 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -3,11 +3,23 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + import subprocess +import json import tempfile +import hashlib +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False -class PlasmaArray(object): +class PlasmaArray: """ Wrapper around numpy arrays that automatically moves the data to shared memory upon serialization. This is particularly helpful when passing numpy @@ -31,12 +43,7 @@ def __init__(self, array): @property def plasma(self): if self._plasma is None and not self.disable: - try: - import pyarrow.plasma as plasma - - self._plasma = plasma - except ImportError: - self._plasma = None + self._plasma = plasma return self._plasma def start_server(self): @@ -47,13 +54,7 @@ def start_server(self): self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name self._server = subprocess.Popen( - [ - "plasma_store", - "-m", - str(int(1.05 * self.array.nbytes)), - "-s", - self.path, - ] + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] ) @property @@ -64,6 +65,7 @@ def client(self): return self._client def __getstate__(self): + """Called on pickle load""" if self.plasma is None: return self.__dict__ if self.object_id is None: @@ -78,6 +80,7 @@ def __getstate__(self): return state def __setstate__(self, state): + """Called on pickle save""" self.__dict__.update(state) if self.plasma is None: return @@ -89,3 +92,106 @@ def __del__(self): self._server = None self._server_tmp.close() self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024 ** 3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index aa33f9d06f..d2c65fd7e0 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -6,6 +6,8 @@ import numpy as np import torch from fairseq.data import FairseqDataset, plasma_utils +from fairseq.data.indexed_dataset import best_fitting_int_dtype +from typing import Tuple class TokenBlockDataset(FairseqDataset): @@ -41,7 +43,46 @@ def __init__( break_mode=None, include_targets=False, document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" try: from fairseq.data.token_block_utils_fast import ( _get_slice_indices_fast, @@ -53,15 +94,6 @@ def __init__( "or `python setup.py build_ext --inplace`" ) - super().__init__() - self.dataset = dataset - self.pad = pad - self.eos = eos - self.include_targets = include_targets - - assert len(dataset) == len(sizes) - assert len(dataset) > 0 - if isinstance(sizes, list): sizes = np.array(sizes, dtype=np.int64) else: @@ -78,7 +110,7 @@ def __init__( slice_indices = _get_slice_indices_fast( sizes, str(break_mode), block_size, document_sep_len ) - self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + _sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices if break_mode == "eos": @@ -87,7 +119,7 @@ def __init__( [ np.arange(len(sizes)), # starting index in dataset np.zeros( - len(sizes), dtype=np.long + len(sizes), dtype=np.compat.long ), # starting offset within starting index np.arange(len(sizes)), # ending index in dataset ], @@ -95,12 +127,15 @@ def __init__( ) else: block_to_dataset_index = _get_block_to_dataset_index_fast( - sizes, - slice_indices, + sizes, slice_indices, ) - self._slice_indices = plasma_utils.PlasmaArray(slice_indices) - self._sizes = plasma_utils.PlasmaArray(self._sizes) - self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) + size_dtype = np.uint16 if block_size < 65535 else np.uint32 + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) + return _sizes, block_to_dataset_index, slice_indices @property def slice_indices(self): @@ -124,7 +159,6 @@ def __getitem__(self, index): buffer = torch.cat( [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] ) - slice_s, slice_e = self.slice_indices[index] length = slice_e - slice_s s, e = start_offset, start_offset + length diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index f66e98fe83..3c29be9197 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -104,15 +104,10 @@ class CommonConfig(FairseqDataclass): ) wandb_project: Optional[str] = field( default=None, - metadata={ - "help": "Weights and Biases project name to use for logging" - }, + metadata={"help": "Weights and Biases project name to use for logging"}, ) azureml_logging: Optional[bool] = field( - default=False, - metadata={ - "help": "Log scalars to AzureML context" - }, + default=False, metadata={"help": "Log scalars to AzureML context"}, ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} @@ -192,6 +187,15 @@ class CommonConfig(FairseqDataclass): "main method can return a value (useful for sweeps)" }, ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail." + }, + ) @dataclass @@ -263,7 +267,7 @@ class DistributedTrainingConfig(FairseqDataclass): metadata={ "help": "kill the job if no progress is made in N seconds; " "set to -1 to disable" - } + }, ) broadcast_buffers: bool = field( default=False, @@ -355,7 +359,19 @@ class DistributedTrainingConfig(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, metadata={"help": "offload FP32 params to CPU"} + ) @dataclass @@ -607,8 +623,17 @@ class CheckpointConfig(FairseqDataclass): "(default: only load on rank 0 and broadcast to other devices)" }, ) + write_checkpoints_asynchronously: bool = field( + default=False, + metadata={ + "help": ( + "Write checkpoints asynchronously in a separate " + "thread. NOTE: This feature is currently being tested." + ), + "argparse_alias": "--save-async", + }, + ) model_parallel_size: int = II("common.model_parallel_size") - distributed_rank: int = II("distributed_training.distributed_rank") @dataclass @@ -641,12 +666,10 @@ class FairseqBMUFConfig(FairseqDataclass): @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( - default=5, - metadata={"help": "beam size"}, + default=5, metadata={"help": "beam size"}, ) nbest: int = field( - default=1, - metadata={"help": "number of hypotheses to output"}, + default=1, metadata={"help": "number of hypotheses to output"}, ) max_len_a: float = field( default=0, @@ -661,24 +684,19 @@ class GenerationConfig(FairseqDataclass): }, ) min_len: int = field( - default=1, - metadata={"help": "minimum generation length"}, + default=1, metadata={"help": "minimum generation length"}, ) match_source_len: bool = field( - default=False, - metadata={"help": "generations should match the source length"}, + default=False, metadata={"help": "generations should match the source length"}, ) unnormalized: bool = field( - default=False, - metadata={"help": "compare unnormalized hypothesis scores"}, + default=False, metadata={"help": "compare unnormalized hypothesis scores"}, ) no_early_stop: bool = field( - default=False, - metadata={"help": "deprecated"}, + default=False, metadata={"help": "deprecated"}, ) no_beamable_mm: bool = field( - default=False, - metadata={"help": "don't use BeamableMM in attention layers"}, + default=False, metadata={"help": "don't use BeamableMM in attention layers"}, ) lenpen: float = field( default=1, @@ -700,12 +718,10 @@ class GenerationConfig(FairseqDataclass): }, ) sacrebleu: bool = field( - default=False, - metadata={"help": "score with sacrebleu"}, + default=False, metadata={"help": "score with sacrebleu"}, ) score_reference: bool = field( - default=False, - metadata={"help": "just score the reference translation"}, + default=False, metadata={"help": "just score the reference translation"}, ) prefix_size: int = field( default=0, @@ -739,12 +755,10 @@ class GenerationConfig(FairseqDataclass): }, ) temperature: float = field( - default=1.0, - metadata={"help": "temperature for generation"}, + default=1.0, metadata={"help": "temperature for generation"}, ) diverse_beam_groups: int = field( - default=-1, - metadata={"help": "number of groups for Diverse Beam Search"}, + default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, ) diverse_beam_strength: float = field( default=0.5, @@ -763,16 +777,13 @@ class GenerationConfig(FairseqDataclass): }, ) print_step: bool = field( - default=False, - metadata={"help": "print steps"}, + default=False, metadata={"help": "print steps"}, ) lm_path: Optional[str] = field( - default=None, - metadata={"help": "path to lm checkpoint for lm fusion"}, + default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, ) lm_weight: float = field( - default=0.0, - metadata={"help": "weight for lm probs for lm fusion"}, + default=0.0, metadata={"help": "weight for lm probs for lm fusion"}, ) # arguments for iterative refinement generator @@ -781,8 +792,7 @@ class GenerationConfig(FairseqDataclass): metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, ) iter_decode_max_iter: int = field( - default=10, - metadata={"help": "maximum iterations for iterative refinement."}, + default=10, metadata={"help": "maximum iterations for iterative refinement."}, ) iter_decode_force_max_iter: bool = field( default=False, @@ -809,8 +819,7 @@ class GenerationConfig(FairseqDataclass): }, ) retain_dropout: bool = field( - default=False, - metadata={"help": "Use dropout at inference time"}, + default=False, metadata={"help": "Use dropout at inference time"}, ) # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # retain_dropout_modules: Optional[List[str]] = field( @@ -835,8 +844,7 @@ class GenerationConfig(FairseqDataclass): @dataclass class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, - metadata={"help": "path(s) to model file(s), colon separated"}, + default=None, metadata={"help": "path(s) to model file(s), colon separated"}, ) post_process: Optional[str] = field( default=None, @@ -898,8 +906,7 @@ class InteractiveConfig(FairseqDataclass): }, ) input: str = field( - default="-", - metadata={"help": "file to read from; use - for stdin"}, + default="-", metadata={"help": "file to read from; use - for stdin"}, ) diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 93bc6d03cb..faba0862fa 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -37,6 +37,7 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum([ "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale "legacy_ddp", "no_c10d", # alias for legacy_ddp "pytorch_ddp", diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index a4d4a412dd..27c9006fdb 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -43,7 +43,9 @@ def interpret_dc_type(field_type): return str typestring = str(field_type) - if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring) or typestring.startswith("typing.Optional"): + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): return field_type.__args__[0] return field_type @@ -235,15 +237,17 @@ def get_default(f): and not (isinstance(val, str) and val.startswith("${")) ): # if type is int but val is float, then we will crash later - try to convert here - if hasattr(v.type, '__args__'): + if hasattr(v.type, "__args__"): t_args = v.type.__args__ - if len(t_args) == 1: + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): val = list(map(t_args[0], val)) - elif val is not None and (field_type is int or field_type is bool or field_type is float): + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): try: val = field_type(val) except: - pass # ignore errors here, they are often from interpolation args + pass # ignore errors here, they are often from interpolation args if val is None: overrides.append("{}.{}=null".format(sub_node, k)) @@ -430,7 +434,7 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): if k in cfg and isinstance(cfg[k], DictConfig): if k in overrides and isinstance(overrides[k], dict): for ok, ov in overrides[k].items(): - if isinstance(ov, dict): + if isinstance(ov, dict) and cfg[k][ok] is not None: overwrite_args_by_name(cfg[k][ok], ov) else: cfg[k][ok] = ov diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py index 7f4016e38c..d0b96b734c 100644 --- a/fairseq/distributed/__init__.py +++ b/fairseq/distributed/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .module_proxy_wrapper import ModuleProxyWrapper from .tpu_distributed_data_parallel import TPUDistributedDataParallel @@ -11,6 +12,9 @@ __all__ = [ "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", "LegacyDistributedDataParallel", "ModuleProxyWrapper", "TPUDistributedDataParallel", diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000000..9c290b3fda --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch + +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + + def state_dict(self, destination=None, prefix='', keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + from fairscale.utils.testing import DummyProcessGroup + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": True, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + } + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=use_sharded_state, + **fsdp_config, + ): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, **kwargs) + else: + return module + else: + return wrap(module, **kwargs) + except ImportError: + return module diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index e3d8e1e0d3..970b784915 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -281,7 +281,6 @@ def distributed_init(cfg: FairseqConfig): cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers - xm.mark_step() if is_master(cfg.distributed_training): logging.getLogger().setLevel(logging.INFO) @@ -325,6 +324,9 @@ def distributed_main(i, main, cfg: FairseqConfig, kwargs): main(cfg, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.barrier(get_global_group()) + def call_main(cfg: FairseqConfig, main, **kwargs): if cfg.distributed_training.distributed_init_method is None: @@ -354,7 +356,10 @@ def call_main(cfg: FairseqConfig, main, **kwargs): xmp.spawn( fn=distributed_main, args=(main, cfg, kwargs), - nprocs=8, # use all 8 TPU cores + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), ) else: # single GPU main diff --git a/fairseq/file_io.py b/fairseq/file_io.py index 7d6c28dccd..9a78ab505d 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -32,6 +32,8 @@ except ImportError: FVCorePathManager = None +IOPathPathManager = None + class PathManager: """ @@ -148,3 +150,47 @@ def supports_rename(path: str) -> bool: @staticmethod def rename(src: str, dst: str): os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathPathManager + if not IOPathPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath.common.file_io import PathManager + IOPathPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. + NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. + """ + global IOPathPathManager + if IOPathPathManager: + return IOPathPathManager.async_close() + return False diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 7b56e31592..2bb1da086f 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -286,3 +286,11 @@ def load_state_dict(state_dict): for name, agg_state in state_dict.items(): _aggregators[name] = MetersDict() _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index dc061a1821..0ae2bc006d 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -123,7 +123,7 @@ def __init__(self, iterable, epoch=None, prefix=None): if epoch is not None: self.prefix += "epoch {:03d}".format(epoch) if prefix is not None: - self.prefix += " | {}".format(prefix) + self.prefix += (" | " if self.prefix != "" else "") + prefix def __len__(self): return len(self.iterable) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 7873611214..7f30dd98bb 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -39,15 +39,47 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 +TORCH_PIPE = False +RPC_INIT = False + +def import_pipe(): + global TORCH_PIPE + global RPC_INIT + try: + from torch.distributed.pipeline.sync import Pipe # noqa + global Pipe + from torch.distributed.pipeline.sync.utils import partition_model + global partition_model + from torch.distributed import rpc + import tempfile + TORCH_PIPE = True + # Initialize single process RPC agent since TORCH_PIPE requires + # RRef. RRef depends on RPC being initialized and as a result we initialize + # RPC with a single node. + tmpfile = tempfile.NamedTemporaryFile() + if not RPC_INIT: + rpc.init_rpc( + name="worker", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(tmpfile.name), + ) + ) + RPC_INIT = True + logger.info('Using torch pipe') + except ImportError: + try: + from fairscale.nn import Pipe # noqa + logger.info('Using fairscale pipe') + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") @register_model("pipeline_parallel_transformer") class PipelineParallelTransformerModel(BaseFairseqModel): def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() super().__init__() assert isinstance(encoder, FairseqEncoder) assert isinstance(decoder, FairseqDecoder) @@ -65,13 +97,20 @@ def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): self.num_decoder_modules = len(decoder_module_list) module_list = encoder_module_list + decoder_module_list self.devices = devices - self.model = Pipe( - nn.Sequential(*module_list), - balance=balance, - devices=devices, - chunks=chunks, - checkpoint=checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + partition_model(nn.Sequential(*module_list), balance, devices), + chunks=chunks, + checkpoint=checkpoint, + ) + else: + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) self.encoder_max_positions = self.max_positions_helper( encoder.embedding_layer, "max_source_positions" ) @@ -87,7 +126,10 @@ def forward(self, src_tokens, src_lengths, prev_output_tokens): if self.training: input_lst = [src_tokens, src_lengths, prev_output_tokens] input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) - return self.model(input) + if TORCH_PIPE: + return self.model(input).local_value() + else: + return self.model(input) else: assert self.encoder is not None and self.decoder is not None, ( "encoder and decoder need to be initialized by " @@ -425,10 +467,7 @@ class TransformerEncoder(FairseqEncoder): def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = encoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) @@ -449,13 +488,20 @@ def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): f"Sum of encoder_balance={encoder_balance} is not equal " + f"to num_encoder_modules={len(encoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*encoder_module_list), - balance=encoder_balance, - devices=encoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) def forward(self, src_tokens, src_lengths): """ @@ -485,7 +531,10 @@ def forward(self, src_tokens, src_lengths): input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - encoder_out = self.model(input_tuple) + if TORCH_PIPE: + encoder_out = self.model(input_tuple).local_value() + else: + encoder_out = self.model(input_tuple) else: encoder_embed_output_tuple = self.embedding_layer(input_tuple) encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) @@ -561,10 +610,7 @@ def __init__( ): super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) - try: - from fairscale.nn import Pipe - except ImportError: - raise ImportError("Please install fairscale with: pip install fairscale") + import_pipe() self.use_pipeline = decoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) @@ -586,13 +632,20 @@ def __init__( f"Sum of decoder_balance={decoder_balance} is not equal " + f"to num_decoder_modules={len(decoder_module_list)}" ) - self.model = Pipe( - module=nn.Sequential(*decoder_module_list), - balance=decoder_balance, - devices=decoder_devices, - chunks=args.pipeline_chunks, - checkpoint=args.pipeline_checkpoint, - ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) def forward( self, @@ -622,7 +675,10 @@ def forward( ) if self.use_pipeline: input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) - return (self.model(input_tuple),) + if TORCH_PIPE: + return (self.model(input_tuple).local_value(),) + else: + return (self.model(input_tuple),) else: embed_layer_output = self.embedding_layer(input_tuple) state = self.decoder_layers(embed_layer_output) diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py index 68ad88d2a5..77a80ef720 100644 --- a/fairseq/model_parallel/models/roberta/model.py +++ b/fairseq/model_parallel/models/roberta/model.py @@ -12,16 +12,15 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoder -from fairseq.models import FairseqEncoder, register_model, register_model_architecture +from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder +from fairseq.models import register_model, register_model_architecture from fairseq.models.roberta import ( - RobertaClassificationHead, + roberta_base_architecture, + roberta_prenorm_architecture, RobertaEncoder, - RobertaLMHead, RobertaModel, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder -from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.modules import LayerNorm try: @@ -29,7 +28,7 @@ copy_to_model_parallel_region, gather_from_model_parallel_region, ColumnParallelLinear, - RowParallelLinear, + VocabParallelEmbedding, ) has_megatron_submodule = True @@ -48,7 +47,15 @@ def __init__(self, args, encoder): @staticmethod def add_args(parser): - super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser) + RobertaModel.add_args(parser) + parser.add_argument( + "--no-final-layer-norm", + action="store_true", + help=( + "don't add final layernorm (only applicable when " + "--encoder-normalize-before=True" + ), + ) @classmethod def build_model(cls, args, task): @@ -165,121 +172,52 @@ def forward(self, features, **kwargs): return x -class ModelParallelRobertaEncoder(FairseqEncoder): - """RoBERTa encoder. - - Implements the :class:`~fairseq.models.FairseqDecoder` interface required - by :class:`~fairseq.models.FairseqLanguageModel`. - """ +class ModelParallelRobertaEncoder(RobertaEncoder): + """RoBERTa encoder.""" def __init__(self, args, dictionary): - super().__init__(dictionary) - self.args = args - - # RoBERTa is a sentence encoder model, so users will intuitively trim - # encoder layers. However, the implementation uses the fairseq decoder, - # so we fix here. - if args.encoder_layers_to_keep: - args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - args.decoder_layers_to_keep = args.encoder_layers_to_keep - args.encoder_layers_to_keep = None - - self.sentence_encoder = ModelParallelTransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=False, - apply_bert_init=False, - activation_fn=args.activation_fn, - ) - self.lm_head = ModelParallelRobertaLMHead( - embed_dim=args.encoder_embed_dim, - output_dim=len(dictionary), - activation_fn=args.activation_fn, - weight=self.sentence_encoder.embed_tokens.weight, - ) - - def forward( - self, - src_tokens, - features_only=False, - return_all_hiddens=False, - masked_tokens=None, - **unused - ): - """ - Args: - src_tokens (LongTensor): input tokens of shape `(batch, src_len)` - features_only (bool, optional): skip LM head and just return - features. If True, the output will be of shape - `(batch, src_len, embed_dim)`. - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - - Returns: - tuple: - - the LM output of shape `(batch, src_len, vocab)` - - a dictionary of additional data, where 'inner_states' - is a list of hidden states. Note that the hidden - states have shape `(src_len, batch, vocab)`. - """ - x, extra = self.extract_features( - src_tokens, return_all_hiddens=return_all_hiddens - ) - if not features_only: - x = self.output_layer(x, masked_tokens=masked_tokens) - return x, extra + super().__init__(args, dictionary) + assert not self.args.untie_weights_roberta - def extract_features(self, src_tokens, return_all_hiddens=False, **unused): - inner_states, _ = self.sentence_encoder( - src_tokens, - last_state_only=not return_all_hiddens, - ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - def output_layer(self, features, masked_tokens=None, **unused): - return self.lm_head(features, masked_tokens) + def build_encoder(self, args, dictionary, embed_tokens): + return ModelParallelTransformerEncoder(args, dictionary, embed_tokens) - def max_positions(self): - """Maximum output length supported by the encoder.""" - return self.args.max_positions + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta") def base_architecture(args): - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False) + # model parallel RoBERTa defaults to "Pre-LN" formulation + roberta_prenorm_architecture(args) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) +# earlier versions of model parallel RoBERTa removed the final layer norm +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1") +def model_parallel_roberta_v1_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True) + base_architecture(args) + + +@register_model_architecture( + "model_parallel_roberta", "model_parallel_roberta_postnorm" +) +def model_parallel_roberta_postnorm_architecture(args): + # the original BERT/RoBERTa uses the "Post-LN" formulation + roberta_base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") -def roberta_base_architecture(args): +def model_parallel_roberta_base_architecture(args): base_architecture(args) @register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") -def roberta_large_architecture(args): +def model_parallel_roberta_large_architecture(args): args.encoder_layers = getattr(args, "encoder_layers", 24) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py index 4f34645226..6b330ef1b7 100644 --- a/fairseq/model_parallel/models/transformer.py +++ b/fairseq/model_parallel/models/transformer.py @@ -6,7 +6,6 @@ import logging import torch.nn as nn -import torch.nn.functional as F from fairseq.model_parallel.modules import ( ModelParallelTransformerDecoderLayer, ModelParallelTransformerEncoderLayer, @@ -86,6 +85,12 @@ class ModelParallelTransformerEncoder(TransformerEncoder): is a :class:`ModelParallelTransformerEncoderLayer`. """ + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + + if args.no_final_layer_norm: + self.layer_norm = None + def build_encoder_layer(self, args): return ModelParallelTransformerEncoderLayer(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py index fb45b3c9e0..11603217a1 100644 --- a/fairseq/model_parallel/modules/__init__.py +++ b/fairseq/model_parallel/modules/__init__.py @@ -9,15 +9,9 @@ ModelParallelTransformerEncoderLayer, ModelParallelTransformerDecoderLayer, ) -from .transformer_sentence_encoder_layer import ( - ModelParallelTransformerSentenceEncoderLayer, -) -from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder __all__ = [ "ModelParallelMultiheadAttention", "ModelParallelTransformerEncoderLayer", "ModelParallelTransformerDecoderLayer", - "ModelParallelTransformerSentenceEncoder", - "ModelParallelTransformerSentenceEncoderLayer", ] diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py deleted file mode 100644 index a5d50a33c6..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import random -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer -from fairseq.modules import ( - LayerNorm, - MultiheadAttention, - PositionalEmbedding, - TransformerSentenceEncoder, -) - - -try: - from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): - """ - Implementation for a Model Parallel Bi-directional Transformer based - Sentence Encoder used in BERT/XLM style pre-trained models. - """ - - def build_embedding(self, vocab_size, embedding_dim, padding_idx): - return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) - - def build_transformer_sentence_encoder_layer( - self, - embedding_dim, - ffn_embedding_dim, - num_attention_heads, - dropout, - attention_dropout, - activation_dropout, - activation_fn, - export, - **unused, - ): - return ModelParallelTransformerSentenceEncoderLayer( - embedding_dim=embedding_dim, - ffn_embedding_dim=ffn_embedding_dim, - num_attention_heads=num_attention_heads, - dropout=dropout, - attention_dropout=attention_dropout, - activation_dropout=activation_dropout, - activation_fn=activation_fn, - export=export, - ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py deleted file mode 100644 index e10bf52332..0000000000 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F -from fairseq import utils -from fairseq.model_parallel.modules import ModelParallelMultiheadAttention -from fairseq.modules import TransformerSentenceEncoderLayer - - -try: - from fairseq.model_parallel.megatron.mpu import ( - ColumnParallelLinear, - RowParallelLinear, - ) - - has_megatron_submodule = True -except (ImportError, ModuleNotFoundError): - has_megatron_submodule = False - - -class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): - """ - Implements a Model Parallel Transformer Encoder Layer used in - BERT/XLM style pre-trained models. - """ - - def build_fc1(self, input_dim, output_dim, **unused): - return ColumnParallelLinear(input_dim, output_dim, gather_output=False) - - def build_fc2(self, input_dim, output_dim, **unused): - return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) - - def build_self_attention( - self, - embed_dim, - num_attention_heads, - dropout, - **kwargs, - ): - return ModelParallelMultiheadAttention( - embed_dim, num_attention_heads, dropout=dropout, self_attention=True - ) - - def forward( - self, - x: torch.Tensor, - self_attn_mask: torch.Tensor = None, - self_attn_padding_mask: torch.Tensor = None, - ): - """ - LayerNorm is applied either before or after the self-attention/ffn - modules similar to the original Transformer imlementation. - """ - residual = x - x = self.self_attn_layer_norm(x) - x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - need_weights=False, - attn_mask=self_attn_mask, - ) - x = self.dropout_module(x) - x = residual + x - - residual = x - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) - x = self.dropout_module(x) - x = residual + x - return x, None diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 1ff170a782..2ddeb763a3 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -92,22 +92,27 @@ def generate( tokenized_sentences: List[torch.LongTensor], *args, inference_step_args=None, + skip_invalid_size_inputs=False, **kwargs ) -> List[List[Dict[str, torch.Tensor]]]: inference_step_args = inference_step_args or {} if "prefix_tokens" in inference_step_args: raise NotImplementedError("prefix generation not implemented for BART") - else: - bsz = len(tokenized_sentences) - inference_step_args["prefix_tokens"] = tokenized_sentences[0].new_full( - (bsz, 1), fill_value=self.task.source_dictionary.bos() + res = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + src_tokens = batch['net_input']['src_tokens'] + inference_step_args["prefix_tokens"] =src_tokens.new_full( + (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() ).to(device=self.device) - return super().generate( - tokenized_sentences, - *args, - inference_step_args=inference_step_args, - **kwargs - ) + results = super().generate( + src_tokens, + *args, + inference_step_args=inference_step_args, + skip_invalid_size_inputs=skip_invalid_size_inputs, + **kwargs + ) + res.extend(results) + return res def extract_features( self, tokens: torch.LongTensor, return_all_hiddens: bool = False diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index ca157f06e9..3422faea74 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -105,12 +105,27 @@ def DistributedFairseqModel(args, model, process_group, device): ) # forward missing getattr and state_dict/load_state_dict to orig model wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) # kill hung distributed jobs after a timeout - wrapped_model = DistributedTimeoutWrapper( - wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) - ) + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) return wrapped_model diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index 7eeb5c652f..4f1e8b52a2 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -64,6 +64,19 @@ def get_normalized_probs( sample: Optional[Dict[str, Tensor]] = None, ): """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: if sample is not None: diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 244cbc0c66..171a8a40f1 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -14,7 +14,6 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary from fairseq.dataclass.utils import ( convert_namespace_to_omegaconf, @@ -28,6 +27,14 @@ logger = logging.getLogger(__name__) +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance(module.unwrapped_module, expected_type), \ + f"{type(module.unwrapped_module)} != {expected_type}" + else: + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" + + class BaseFairseqModel(nn.Module): """Base class for fairseq models.""" @@ -111,6 +118,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) @@ -282,8 +292,9 @@ def __init__(self, encoder, decoder): self.encoder = encoder self.decoder = decoder - assert isinstance(self.encoder, FairseqEncoder) - assert isinstance(self.decoder, FairseqDecoder) + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ @@ -363,8 +374,8 @@ def __init__(self, encoders, decoders): assert encoders.keys() == decoders.keys() self.keys = list(encoders.keys()) for key in self.keys: - assert isinstance(encoders[key], FairseqEncoder) - assert isinstance(decoders[key], FairseqDecoder) + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) self.models = nn.ModuleDict( { @@ -450,6 +461,9 @@ def load_state_dict( model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) @@ -464,7 +478,7 @@ class FairseqLanguageModel(BaseFairseqModel): def __init__(self, decoder): super().__init__() self.decoder = decoder - assert isinstance(self.decoder, FairseqDecoder) + check_type(self.decoder, FairseqDecoder) def forward(self, src_tokens, **kwargs): """ @@ -525,7 +539,7 @@ class FairseqEncoderModel(BaseFairseqModel): def __init__(self, encoder): super().__init__() self.encoder = encoder - assert isinstance(self.encoder, FairseqEncoder) + check_type(self.encoder, FairseqEncoder) def forward(self, src_tokens, src_lengths, **kwargs): """ diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 00a5a5485f..5d2ed4902d 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -18,7 +18,8 @@ register_model, register_model_architecture, ) -from fairseq.modules import LayerNorm, TransformerSentenceEncoder +from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder +from fairseq.modules import LayerNorm from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from fairseq.modules.transformer_sentence_encoder import init_bert_params @@ -87,6 +88,11 @@ def add_args(parser): action="store_true", help="apply layernorm before each encoder block", ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) parser.add_argument( "--dropout", type=float, metavar="D", help="dropout probability" ) @@ -116,6 +122,11 @@ def add_args(parser): action="store_true", help="(re-)register and load heads when loading checkpoints", ) + parser.add_argument( + "--untie-weights-roberta", + action="store_true", + help="Untie weights between embeddings and classifiers in RoBERTa", + ) # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument( "--encoder-layerdrop", @@ -151,17 +162,28 @@ def add_args(parser): default=0, help="scalar quantization noise and scalar quantization at training time", ) - parser.add_argument( - "--untie-weights-roberta", - action="store_true", - help="Untie weights between embeddings and classifiers in RoBERTa", - ) + # args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 2020) parser.add_argument( "--spectral-norm-classification-head", action="store_true", default=False, help="Apply spectral normalization on the classification head", ) + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + "--min-params-to-wrap", + type=int, + metavar="D", + default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + ) + ) @classmethod def build_model(cls, args, task): @@ -264,6 +286,13 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict[new_k] = state_dict[k] del state_dict[k] + # rename emb_layer_norm -> layernorm_embedding + for k in list(state_dict.keys()): + if ".emb_layer_norm." in k: + new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.") + state_dict[new_k] = state_dict[k] + del state_dict[k] + # upgrade children modules super().upgrade_state_dict_named(state_dict, name) @@ -401,7 +430,11 @@ def __init__(self, args, dictionary): if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - self.sentence_encoder = self.build_encoder(args, dictionary) + embed_tokens = self.build_embedding( + len(dictionary), args.encoder_embed_dim, dictionary.pad() + ) + + self.sentence_encoder = self.build_encoder(args, dictionary, embed_tokens) self.lm_head = self.build_lm_head( embed_dim=args.encoder_embed_dim, @@ -414,26 +447,16 @@ def __init__(self, args, dictionary): ), ) - def build_encoder(self, args, dictionary): - return TransformerSentenceEncoder( - padding_idx=dictionary.pad(), - vocab_size=len(dictionary), - num_encoder_layers=args.encoder_layers, - embedding_dim=args.encoder_embed_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - layerdrop=args.encoder_layerdrop, - max_seq_len=args.max_positions, - num_segments=0, - encoder_normalize_before=True, - apply_bert_init=True, - activation_fn=args.activation_fn, - q_noise=args.quant_noise_pq, - qn_block_size=args.quant_noise_pq_block_size, - ) + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + encoder = TransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) @@ -470,13 +493,15 @@ def forward( return x, extra def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): - inner_states, _ = self.sentence_encoder( + encoder_out = self.sentence_encoder( src_tokens, - last_state_only=not return_all_hiddens, + return_all_hiddens=return_all_hiddens, token_embeddings=kwargs.get("token_embeddings", None), ) - features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C - return features, {"inner_states": inner_states if return_all_hiddens else None} + # T x B x C -> B x T x C + features = encoder_out["encoder_out"][0].transpose(0, 1) + inner_states = encoder_out["encoder_states"] if return_all_hiddens else None + return features, {"inner_states": inner_states} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) @@ -493,21 +518,50 @@ def base_architecture(args): args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.dropout = getattr(args, "dropout", 0.1) args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_dropout = getattr(args, "activation_dropout", 0.0) args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) - args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + args.max_source_positions = getattr(args, "max_positions", 512) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + # BERT has a few structural differences compared to the original Transformer + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) + + # Adaptive input config + args.adaptive_input = getattr(args, "adaptive_input", False) + + # LayerDrop config + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + + # Quantization noise config + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + # R4F config args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False ) +@register_model_architecture("roberta", "roberta_prenorm") +def roberta_prenorm_architecture(args): + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + base_architecture(args) + + @register_model_architecture("roberta", "roberta_base") def roberta_base_architecture(args): base_architecture(args) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index 5d7f59b3a6..c6ae9b17ba 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -4,4 +4,5 @@ # LICENSE file in the root directory of this source tree. from .berard import * # noqa +from .convtransformer import * # noqa from .s2t_transformer import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py new file mode 100644 index 0000000000..a4cbbcdeeb --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python3 + +import logging +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq import checkpoint_utils, utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer +from torch import Tensor + +logger = logging.getLogger(__name__) + + +@register_model("convtransformer") +class ConvTransformerModel(FairseqEncoderDecoderModel): + """ + Transformer-based Speech translation model from ESPNet-ST + https://arxiv.org/abs/2004.10234 + """ + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension (extra linear layer if different from decoder embed dim)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="INT", + help="the number of output channels of conv layer", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None): + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens) + if getattr(args, "load_pretrained_decoder_from", None): + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + @staticmethod + @torch.jit.unused + def set_batch_first(lprobs): + lprobs.batch_first = True + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + if self.training: + self.set_batch_first(lprobs) + return lprobs + + def output_layout(self): + return "BTD" + + """ + The forward method inherited from the base class has a **kwargs argument in + its input, which is not supported in torchscript. This method overrites the forward + method definition without **kwargs. + """ + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + return decoder_out + + +class ConvTransformerEncoder(FairseqEncoder): + """Conv + Transformer encoder""" + + def __init__(self, args): + """Construct an Encoder object.""" + super().__init__(None) + + self.dropout = args.dropout + self.embed_scale = ( + 1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim) + ) + self.padding_idx = 1 + self.in_channels = 1 + self.input_dim = args.input_feat_per_channel + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, args.conv_out_channels, 3, stride=2, padding=3 // 2), + torch.nn.ReLU(), + torch.nn.Conv2d( + args.conv_out_channels, + args.conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = self.infer_conv_output_dim( + self.in_channels, self.input_dim, args.conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, + args.encoder_embed_dim, + self.padding_idx, + learned=False, + ) + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def pooling_ratio(self): + return 4 + + def infer_conv_output_dim(self, in_channels, input_dim, out_channels): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + x = torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=3 // 2)(x) + x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + def forward(self, src_tokens, src_lengths): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + + input_lengths = torch.min( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long() + ) + + encoder_padding_mask = lengths_to_padding_mask(input_lengths) + + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + + if not encoder_padding_mask.any(): + maybe_encoder_padding_mask = None + else: + maybe_encoder_padding_mask = encoder_padding_mask + + return { + "encoder_out": [x], + "encoder_padding_mask": [maybe_encoder_padding_mask] + if maybe_encoder_padding_mask is not None + else [], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + (encoder_out["encoder_embedding"][0]).index_select(0, new_order) + ] + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, + "encoder_padding_mask": new_encoder_padding_mask, + "encoder_embedding": new_encoder_embedding, + "encoder_states": encoder_states, + "src_tokens": [], + "src_lengths": [], + } + + +class TransformerDecoderNoExtra(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="convtransformer", arch_name="convtransformer") +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.max_source_positions = getattr(args, "max_source_positions", 3000) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.conv_out_channels = getattr(args, "conv_out_channels", args.encoder_embed_dim) + + +@register_model_architecture("convtransformer", "convtransformer_espnet") +def convtransformer_espnet(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py new file mode 100644 index 0000000000..e7465bc889 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -0,0 +1,488 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, List + +import torch +import torch.nn.functional as F +from fairseq.models import FairseqEncoder +from fairseq.models.speech_to_text import ( + ConvTransformerEncoder, +) +from fairseq.models.speech_to_text.utils import attention_suppression +from fairseq.models.speech_to_text.utils import ( + lengths_to_encoder_padding_mask, + segments_to_sequence, + sequence_to_segments, +) +from fairseq.modules import MultiheadAttention, TransformerEncoderLayer +from torch import nn, Tensor + +# ------------------------------------------------------------------------------ +# AugmentedMemoryConvTransformerEncoder +# ------------------------------------------------------------------------------ + + +class AugmentedMemoryConvTransformerEncoder(ConvTransformerEncoder): + def __init__(self, args): + super().__init__(args) + + args.encoder_stride = self.stride() + + self.left_context = args.left_context // args.encoder_stride + + self.right_context = args.right_context // args.encoder_stride + + self.left_context_after_stride = args.left_context // args.encoder_stride + self.right_context_after_stride = args.right_context // args.encoder_stride + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [ + AugmentedMemoryTransformerEncoderLayer(args) + for i in range(args.encoder_layers) + ] + ) + + def stride(self): + # Hard coded here. Should infer from convs in future + stride = 4 + return stride + + def forward(self, src_tokens, src_lengths, states=None): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = 1.0 * max_seq_len / output_seq_len + input_lengths = torch.max( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long(), + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + # TODO: fix positional embedding + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # State to store memory banks etc. + if states is None: + states = [ + {"memory_banks": None, "encoder_states": None} + for i in range(len(self.transformer_layers)) + ] + + for i, layer in enumerate(self.transformer_layers): + # x size: + # (self.left_size + self.segment_size + self.right_size) + # / self.stride, num_heads, dim + # TODO: Consider mask here + x = layer(x, states[i]) + states[i]["encoder_states"] = x[ + self.left_context_after_stride : -self.right_context_after_stride + ] + + lengths = ( + ( + ~encoder_padding_mask[ + :, self.left_context_after_stride : -self.right_context_after_stride + ] + ) + .sum(dim=1, keepdim=True) + .long() + ) + + return states[-1]["encoder_states"], lengths, states + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryTransformerEncoderLayer +# ------------------------------------------------------------------------------ +class AugmentedMemoryTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, args): + super().__init__(args) + + self.left_context = args.left_context // args.encoder_stride + self.right_context = args.right_context // args.encoder_stride + + def forward(self, x, state): + + length, batch_size, x_dim = x.size() + + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + # init_state + if state.get("memory_banks", None) is None: + state["memory_banks"] = [] + + # TODO reseach new sum_query method + seg_start = self.left_context + seg_end = length - self.right_context + if seg_start < seg_end: + summarization_query = torch.mean(x[seg_start:seg_end], keepdim=True, dim=0) + else: + summarization_query = x.new_zeros(1, batch_size, x_dim) + + x = torch.cat([x, summarization_query], dim=0) + + x = self.self_attn(input_and_summary=x, state=state) + + x = self.dropout_module(x) + x = residual + x + + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + + def build_self_attention(self, embed_dim, args): + return AugmentedMemoryMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + tanh_on_mem=True, + max_memory_size=args.max_memory_size, + ) + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryMultiheadAttention +# ------------------------------------------------------------------------------ +class AugmentedMemoryMultiheadAttention(MultiheadAttention): + """ + Augmented Memory Attention from + Streaming Transformer-based Acoustic Models + Using Self-attention with Augmented Memory + https://arxiv.org/abs/2005.08042 + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + tanh_on_mem=False, + memory_dim=None, + std_scale=0.5, # 0.5 based on https://arxiv.org/abs/2005.09137 + max_memory_size=-1, + disable_mem_on_mem_attn=True, + ): + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + q_noise, + qn_block_size, + ) + + self.memory_dim = memory_dim if memory_dim is not None else embed_dim + self.std_scale = std_scale + self.disable_mem_on_mem_attn = disable_mem_on_mem_attn + + # This Operator was used for factorization in PySpeech + self.v2e = lambda x: x + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = lambda x: x + self.nonlinear_squash_mem = False + + self.max_memory_size = max_memory_size + + def forward(self, input_and_summary, state): + """ + input: Encoder states of current segment with left or right context, + plus one summarization query + + """ + + length, batch_size, _ = input_and_summary.shape + length = length - 1 # not include sum_query, last index + + memory = state["memory_banks"] + # TODO: positional embedding on memory + + if self.max_memory_size > -1 and len(memory) > self.max_memory_size: + # TODO: need to fix here + if self.max_memory_size == 0: + memory = memory.new_zeros(1, memory.size(1), self.memory_dim) + else: + memory = memory[-self.max_memory_size :] + + memory_and_input = torch.cat(memory + [input_and_summary[:-1]], dim=0) + input_and_sum_query = input_and_summary + + q = self.q_proj(self.v2e(input_and_sum_query)) + k = self.k_proj(self.v2e(memory_and_input)) + v = self.v_proj(self.v2e(memory_and_input)) + + q = ( + q.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + * self.scaling + ) + k = ( + k.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + v.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + + if self.disable_mem_on_mem_attn: + attention_weights = self.suppress_mem_on_mem_attention( + batch_size, self.num_heads, len(memory), attention_weights + ) + + if self.std_scale is not None: + attention_weights = attention_suppression(attention_weights, self.std_scale) + + assert list(attention_weights.shape) == [ + batch_size * self.num_heads, + length + 1, + length + len(memory), + ] + + attention_weights = torch.nn.functional.softmax( + attention_weights.float(), dim=-1 + ).type_as(attention_weights) + + attention_probs = self.dropout_module(attention_weights) + + # [T, T, B, n_head] + [T, B, n_head, d_head] -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + + assert list(attention.shape) == [ + batch_size * self.num_heads, + length + 1, + self.head_dim, + ] + + attention = ( + attention.transpose(0, 1) + .contiguous() + .view(length + 1, batch_size, self.embed_dim) + ) + + output_and_memory = self.out_proj(attention) + + next_m = output_and_memory[-1:] + next_m = self.squash_mem(next_m) + output = output_and_memory[:-1] + + state["memory_banks"].append(next_m) + + return output + + def suppress_mem_on_mem_attention( + self, B: int, num_heads: int, mem_size: int, attention_weight: Tensor + ): + """ + Arguments: + - B: batch size + - num_heads: number of attention heads + - mem_size: size of memory bank + - attention_weight: a [B*num_heads, T + 1, T + mem_size] vector + + Return: + modified attention_weight with [B*num_heads, -1, :mem_size] = -inf + """ + attention_weight[:, -1, :mem_size] = float("-inf") + return attention_weight + + +# ------------------------------------------------------------------------------ +# SequenceEncoder +# ------------------------------------------------------------------------------ +class SequenceEncoder(FairseqEncoder): + """ + SequenceEncoder encodes sequences. + + More specifically, `src_tokens` and `src_lengths` in `forward()` should + describe a batch of "complete" sequences rather than segments. + + Segment-by-segment inference can be triggered by `segment_size`: + 1) `segment_size` is None: + SequenceEncoder treats the input sequence as one single segment. + 2) `segment_size` is not None (some int instead): + SequenceEncoder does the following: + 1. breaks the input sequence into several segments + 2. inference on each segment and collect the outputs + 3. concatanete segment outputs into the output sequence. + Note that `segment_size` here shouldn't include additional left/right + contexts needed, for example if we wish to infer with LC-BLSTM where the + middle chunk size is 100 and right context is 20, `segment_size` should be + 100. + """ + + def __init__(self, args, module): + super().__init__(None) + + self.module = module + self.input_time_axis = 1 + self.output_time_axis = 0 + self.segment_size = args.segment_size + self.left_context = args.left_context + self.right_context = args.right_context + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + states=None, + ): + + seg_src_tokens_lengths = sequence_to_segments( + sequence=src_tokens, + time_axis=self.input_time_axis, + lengths=src_lengths, + segment_size=self.segment_size, + extra_left_context=self.left_context, + extra_right_context=self.right_context, + ) + + seg_encoder_states_lengths: List[Tuple[Tensor, Tensor]] = [] + + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + + seg_encoder_states_lengths.append((seg_encoder_states, seg_enc_lengths)) + + encoder_out, enc_lengths = segments_to_sequence( + segments=seg_encoder_states_lengths, time_axis=self.output_time_axis + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + enc_lengths, batch_first=True + ) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + return { + "encoder_out": [encoder_out], + "encoder_padding_mask": [encoder_padding_mask], + "encoder_embedding": [], + "encoder_states": [states], + "src_tokens": [], + "src_lengths": [], + } + + def incremental_encode( + self, + seg_src_tokens: Tensor, + seg_src_lengths: Tensor, + states=None, + ): + """ + Different from forward function, this function takes segmented speech + as input, and append encoder states to previous states + """ + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + return seg_encoder_states, seg_enc_lengths, states + + +# ------------------------------------------------------------------------------ +# Augmented memory model decorator +# ------------------------------------------------------------------------------ +def augmented_memory(klass): + class StreamSeq2SeqModel(klass): + @staticmethod + def add_args(parser): + super(StreamSeq2SeqModel, StreamSeq2SeqModel).add_args(parser) + parser.add_argument( + "--segment-size", type=int, required=True, help="Length of the segment." + ) + parser.add_argument( + "--left-context", + type=int, + default=0, + help="Left context for the segment.", + ) + parser.add_argument( + "--right-context", + type=int, + default=0, + help="Right context for the segment.", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right context for the segment.", + ) + + StreamSeq2SeqModel.__name__ = klass.__name__ + return StreamSeq2SeqModel diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py new file mode 100644 index 0000000000..e026b86847 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -0,0 +1,1836 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import math +import re +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq.models import ( + FairseqEncoder, +) +from fairseq.models.speech_to_text.utils import ( + NoOp, + lengths_to_padding_mask, + segments_to_sequence, +) +from fairseq.models.speech_to_text.utils import ( + attention_suppression, + layer_norm_backward_hook, +) +from torch import Tensor, device as Device +from torch.quantization.qconfig import ( + default_dynamic_qconfig, + per_channel_dynamic_qconfig, +) + + +class RelativePositionEmbedding(nn.Module): + """ + Implementation according to https://arxiv.org/abs/1803.02155 + """ + + def __init__(self, head_dim, max_position, norm_init=True): + super().__init__() + self.head_dim = head_dim + self.max_position = max_position + self.embeddings = nn.Parameter(torch.Tensor(max_position * 2 + 1, head_dim)) + if norm_init: + nn.init.xavier_normal_(self.embeddings) + else: + nn.init.xavier_uniform_(self.embeddings) + + def forward(self, input: Tensor): + output = nn.functional.embedding(input.long(), self.embeddings) + return output + + +class Fp32LayerNorm(nn.Module): + def __init__( + self, + input_dim, + clamp_grad=True, + max_grad_value=256, + eps=1e-5, + elementwise_affine=True, + ): + super().__init__() + self.torch_module = torch.nn.LayerNorm( + input_dim, eps=eps, elementwise_affine=elementwise_affine + ) + if clamp_grad: + hook = partial(layer_norm_backward_hook, clamp_value=max_grad_value) + self.torch_module.register_backward_hook(hook) + + def forward(self, input): + output = torch.nn.functional.layer_norm( + input.float(), + self.torch_module.normalized_shape, + self.torch_module.weight.float() + if self.torch_module.weight is not None + else None, + self.torch_module.bias.float() + if self.torch_module.bias is not None + else None, + self.torch_module.eps, + ).type_as(input) + return output + + +# ------------------------------------------------------------------------------ +# PositionwiseFF +# ------------------------------------------------------------------------------ + + +class PositionwiseFF(nn.Module): + """ + FFN layer in transformer. + + Args: + input_dim: input embedding dimension + ffn_dim: FFN layer inner dimension + dropout_on_fc1: dropout for first linear layer + dropout_on_fc2: dropout fr second linear layer + activation_fn: activation function used after first linear layer. \ + Only relu or gelu is supported. + + """ + + def __init__( + self, input_dim, ffn_dim, dropout_on_fc1, dropout_on_fc2, activation_fn + ): + super(PositionwiseFF, self).__init__() + + self.input_dim = input_dim + self.ffn_dim = ffn_dim + if activation_fn == "relu": + ac = nn.ReLU() + elif activation_fn == "gelu": + ac = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(activation_fn)) + + # fc1 -> ac -> dropout -> fc2 -> dropout + self.module = nn.Sequential( + nn.Linear(input_dim, ffn_dim), + ac, + nn.Dropout(dropout_on_fc1), + nn.Linear(ffn_dim, input_dim), + nn.Dropout(dropout_on_fc2), + ) + + self.layer_norm = Fp32LayerNorm(input_dim) + + def forward(self, input): + module_out = self.module(self.layer_norm(input)) + output = module_out + input + + return output + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# SummarizationLayer +# ------------------------------------------------------------------------------ + + +class SummarizationLayer(nn.Module): + def __init__(self, method, segment_size, embedding_dim): + super(SummarizationLayer, self).__init__() + self.segment_size = segment_size + self.embedding_dim = embedding_dim + nonlin_match = re.match(r"nonlinear\((?P[a-z]+),(?P[0-9]+)\)", method) + self.method = method + if method == "mean": + self.module = nn.AvgPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "max": + self.module = nn.MaxPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "linear": + self.module = nn.Linear(segment_size, 1) + elif nonlin_match: + nonlin_args = nonlin_match.groupdict() + act_type = nonlin_args["act"] + hid_dim = int(nonlin_args["dim"]) + if act_type == "relu": + act = nn.ReLU() + elif act_type == "gelu": + act = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(act_type)) + self.module = nn.Sequential( + nn.Linear(segment_size, hid_dim), + act, + nn.Linear(hid_dim, 1), + ) + else: + raise ValueError("Unsupported summarization method = ({})".format(method)) + + def forward(self, input): + # T, B, D -> B, D, T + input = input.permute(1, 2, 0) + + if self.method == "mean" or self.method == "max": + output = self.module(input) + output = output.permute(2, 0, 1) + return output + + full_seg_length = input.size(2) // self.segment_size * self.segment_size + if full_seg_length > 0: + # at least one seg is full + B = input.size(0) + D = input.size(1) + input_todo = ( + input[:, :, :full_seg_length] + .contiguous() + .view(B, -1, self.segment_size) + ) + output = self.module(input_todo) + output = output.view(B, D, -1) + else: + output = input.new_zeros(input.size(0), input.size(1), 0) + left = input.size(2) - full_seg_length + if left > 0: + # when last seg is not full, use zeros as last memory placeholder + zeros = input.new_zeros(input.size(0), input.size(1), 1) + output = torch.cat([output, zeros], dim=2) + output = output.permute(2, 0, 1) + return output + + +# ------------------------------------------------------------------------------ +# NoSegAugmentedMemoryMultiheadAttentionBmm +# ------------------------------------------------------------------------------ + + +class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module): + """ + Whole utterance augmented memory multihead attention using BMM. + + Different with previous augmented memory multihead attention where + the utterance is chunked into segments. Here we use attention mask + achieve so. The input embedding [right_context, utterance, summary] + is a concatenation of right context, utterance and summary. + + Right context block is the concatenation of all the right context for + each segments. [right_context_0, right_context_1, ..., right_context_n] + For example, if we have utterance = [v0, v1, v2, ...., v20]. segment + size 8, right_context size 4. Then the right context blocks = + [v8, v9, v10, v11, v16, v17, v18, v19, 0, 0, 0, 0], where v8, v9, v10, + and v11 are the right context for first segment. v16, v17, v18 and v19 + are the right context for second segment. 0, 0, 0 and 0 are right context + for the last segment. + + utterance is corresponding to input embedding sequence + + summary is concatenation of average of each segments. [summary_0, + summary_1, ..., ]. + + In augmented memory multihead attention, the query is [right_context, + utterance, summary], key is [memory, right_context, utterance]. Different + with AugmentedMemoryMultiheadAttentionBmm, memory here is passed from + previous attention layer. For the first attention layer, memory is average + of each segment. + + Memory is a concatenation of memory from each segments in previous attention + layer. For example, current layer is i, then memory is [m_0, m_1, ..., m_n]. + Each m_k is the output from seg_k in layer i-1. + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + dropout: attention dropout + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + scaled_init: whether to use scaled init for linear weight + tanh_on_mem: whether to use tanh on memory output + use_mem: whether to use memory or not. When max_memory_size is 0, then + we don't have memory anymore. + layer_index: current self-attention layer index that is used in depth + initialization + max_relative_position: max relative position used in relative position + embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + + """ + + def __init__( + self, + input_dim, + num_heads, + dropout=0.0, + std_scale=None, + scaled_init=False, + tanh_on_mem=False, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + max_relative_position=0, + rpe_old_option=True, + ): + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + super().__init__() + + embed_dim = input_dim + self.e2h_kv = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) + self.e2h_q = torch.nn.Linear(input_dim, input_dim, bias=True) + self.rpe_old_option = rpe_old_option + if max_relative_position > 0: + self.use_rpe = True + self.rpe_k = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + self.rpe_v = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + else: + self.use_rpe = False + self.rpe_k = None + self.rpe_v = None + if scaled_init: + if layer_index == -1: + gain = 1.0 / math.sqrt(2) + else: + # https://arxiv.org/abs/2005.09684 depthwise initialization + # stablize the training greatly. Use depthwise initialization to + # replace incremental loss. + gain = 1.0 / math.sqrt(layer_index + 1) + torch.nn.init.xavier_uniform_(self.e2h_kv.weight, gain=gain) + torch.nn.init.xavier_uniform_(self.e2h_q.weight, gain=gain) + + self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True) + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim ** -0.5 + + self.std_scale = std_scale + self.use_mem = use_mem + self.mini_batches = mini_batches + self.negative_inf = negative_inf + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = NoOp() + self.nonlinear_squash_mem = False + + def prepare_qkv( + self, + input: Tensor, + mems: Tensor, + lengths: Tensor, + summary_length: int, + lc_length: int, + ): + # T: right_context length + utterance_length + summary_length + T, B, D = input.shape + mem_length = mems.size(0) + utterance_length = torch.max(lengths) + + right_context_blocks_length = T - utterance_length - summary_length + rc_block = input[:right_context_blocks_length, :, :] + utterance_block = input[right_context_blocks_length : T - summary_length, :, :] + + if B == 1: + padding_mask = None + else: + klengths = lengths + mem_length + right_context_blocks_length + lc_length + padding_mask = lengths_to_padding_mask(lengths=klengths) + + mem_rc_input = torch.cat([mems, rc_block, utterance_block], dim=0) + + # In training lc_length = 0 + key_length = mem_rc_input.size(0) + lc_length + rc_input_sum = input + q = self.e2h_q(rc_input_sum) + kv = self.e2h_kv(mem_rc_input) + k, v = kv.chunk(chunks=2, dim=2) + result_qkv = (q, k, v) + input_shape = (T, B, D) + result_lengths_info = ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) + if padding_mask is not None: + assert padding_mask.size(0) == B + assert padding_mask.size(1) == key_length + + return result_qkv, input_shape, result_lengths_info, padding_mask + + def prepare_attention_weights( + self, + q: Tensor, + new_k: Tensor, + new_v: Tensor, + input_shape: Tuple[int, int, int], + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + T, B, D = input_shape + q = ( + q.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1) + * self.scaling + ) + + k = ( + new_k.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + new_v.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_k = self.rpe_k(rpe) + # [q, B*h, d] * [q, k, d] -> [B*h, q, k] + attention_weights_rpe = torch.matmul( + q.transpose(0, 1), r_k.transpose(1, 2) + ).transpose(0, 1) + attention_weights = attention_weights + attention_weights_rpe + attention_weights_float = attention_weights.float() + + return attention_weights, attention_weights_float, v + + def prepare_attention_output( + self, + attention_weights: Tensor, + attention_weights_float: Tensor, + v: Tensor, + input_shape: Tuple[int, int, int], + key_length: int, + padding_mask: Optional[Tensor], + rpe: Optional[Tensor], + ) -> Tensor: + T, B, D = input_shape + if padding_mask is not None: + attention_weights_float = attention_weights_float.view( + B, self.num_heads, T, key_length + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attention_weights_float = attention_weights_float.view( + B * self.num_heads, T, key_length + ) + + if self.std_scale is not None: + attention_weights_float = attention_suppression( + attention_weights_float, self.std_scale + ) + + attention_weights_float = torch.nn.functional.softmax( + attention_weights_float, dim=-1 + ) + attention_weights = attention_weights_float.type_as(attention_weights) + + attention_probs = torch.nn.functional.dropout( + attention_weights, p=self.dropout, training=self.training + ) + + # [T, key_length, B, n_head]+ [key_length, B, n_head, d_head] + # -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_v = self.rpe_v(rpe) + attention_rpe = torch.matmul( + attention_probs.transpose(0, 1), r_v + ).transpose(0, 1) + + if self.rpe_old_option: + attention += attention + attention_rpe + else: + attention = attention + attention_rpe + + assert list(attention.shape) == [B * self.num_heads, T, self.head_dim] + + attention = attention.transpose(0, 1).contiguous().view(T, B, self.embed_dim) + + rc_output_memory = self.out_proj(attention) + return rc_output_memory + + @torch.jit.unused + def forward( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + attention_mask: Tensor, + pre_mems: Optional[Tensor] = None, + left_context_key: Optional[Tensor] = None, + left_context_val: Optional[Tensor] = None, + rpe: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in training. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + attention_mask: attention mask for query = [right_context, query, summary] + key = [mem, right_context, query]. This is only used for traing. + + """ + if self.use_mem: + mem_length = mems.size(0) + summary_length = mem_length + 1 + if pre_mems is not None: + mems = torch.cat([pre_mems, mems], dim=0) + else: + mem_length = 0 + summary_length = 0 + + # In training, lc_length = 0 + if left_context_key is not None: + lc_length = left_context_key.size(0) + else: + lc_length = 0 + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + if left_context_key is not None: + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + else: + new_k = k + new_v = v + next_k = None + next_v = None + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + + # mask attention + attention_mask = attention_mask.unsqueeze(0) + attention_weights_float = attention_weights_float.masked_fill( + attention_mask, float(self.negative_inf) + ) + + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + if self.use_mem: + # next_m length equals to summary length - 1 + # last memory is ignored + if self.mini_batches: + next_m = rc_output_memory[-summary_length:] + else: + next_m = rc_output_memory[-summary_length:-1] + + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-summary_length] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + next_m = mems + rc_output = rc_output_memory + + return rc_output, next_m, next_k, next_v + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in decoding. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + left_context_key: left_context for key part. This is only used for online + decoding. In training, this is empty tensor + left_context_val: left_context for value part. This is only used for online + decoding. In training, this is empty tensor + + """ + lc_length = left_context_key.size(0) + + # In decoding, summary_length = 1 or 0 + if self.use_mem: + summary_length = 1 + else: + summary_length = 0 + + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + # In online decoding, we don't have attention mask. But we still need + # to disable the attention from summary query to memory + attention_weights_float[:, -1, :mem_length] = float(self.negative_inf) + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + # In decoding, summary length is 1 + if self.use_mem: + next_m = rc_output_memory[-1:] + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-1] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + rc_output = rc_output_memory + # empty tensor as input mems + next_m = mems + + return rc_output, next_m, next_k, next_v + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +class NoSegAugmentedMemoryTransformer(nn.Module): + """ + Whole utterance augmented memory transformer. + + This is not pyspeech nn layer. It is used as a module in a master layer where + multiple transformers is used. + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + dropout_in_attn=0.0, + dropout_on_attn=None, + dropout_on_fc1=None, + dropout_on_fc2=None, + activation_fn="relu", + tanh_on_mem=False, + std_scale=None, + scaled_init=False, + segment_size=128, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super(NoSegAugmentedMemoryTransformer, self).__init__() + + self.attention = NoSegAugmentedMemoryMultiheadAttentionBmm( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout_in_attn, + scaled_init=scaled_init, + tanh_on_mem=tanh_on_mem, + std_scale=std_scale, + use_mem=use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + max_relative_position=max_relative_position, + ) + self.dropout = nn.Dropout(dropout_on_attn) + self.pos_ff = PositionwiseFF( + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + activation_fn=activation_fn, + ) + self.layer_norm_pre = Fp32LayerNorm(input_dim) + self.layer_norm = Fp32LayerNorm(input_dim) + self.segment_size = segment_size + self.use_mem = use_mem + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + def set_mini_batches(self, mini_batches): + self.attention.mini_batches = mini_batches + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def pre_attention_ops(self, input, right_context_blocks): + rc_length = right_context_blocks.size(0) + input_length = input.size(0) + + rc_and_input = torch.cat([right_context_blocks, input], dim=0) + residual_input = rc_and_input + rc_and_input = self.layer_norm_pre(rc_and_input) + + query_input = rc_and_input[-input_length:, :, :] + return rc_length, input_length, residual_input, query_input, rc_and_input + + def after_attention_ops(self, attention_output, residual_input): + output = self.dropout(attention_output) + output = output + residual_input + output = self.pos_ff(output) + output = self.layer_norm(output) + return output + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + right_context_blocks: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + + # In online decoding, the summary query size is always 1 or 0 + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + summary_query = summary_query[0:1, :, :] + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention.forward_jit( + input=rc_qu_su, + lengths=lengths, + mems=mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + return results + + @torch.jit.unused + def forward( + self, + input, + lengths, + mems, + right_context_blocks, + attention_mask, + pre_mems, + left_context_key, + left_context_val, + rpe, + ): + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention( + input=rc_qu_su, + lengths=lengths, + mems=mems, + attention_mask=attention_mask, + pre_mems=pre_mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + + # [TODO] Note memory did not go through pos_ff. What happen if we pass + # memory through the pos_ff as well? + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + + return results + + +class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder): + """ + Whole utterance augmented memory transformer encoder layer. This is a master layer + where we can define multiple augmented memory transformers. There are two reasons + to setup the master layer. + 1. We only need to define once about the attention mask. All the layers in the master + layer share the same mask. + 2. pyspeech nn layer has special input and output format. Defining one master layer is + easier to passing memory between different layes inside the master layer + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + ffn_dim: ffn dimension in FFN layer + num_layers: number of augmented memory transformer layers + dropout_in_attn: dropout used in multi-head self-attention + dropout_on_attn: dropout used for output from te multihead self-attention + dropout_on_fc1: dropout used in FFN layer for the first linear layer + dropout_on_fc2: dropout used in FFN layer for the second linear layer + segment_size: segment size for each segment + context_config: (left_context_size, right_context_size) defines the surround context size + for each segment + max_memory_size: maximum memory size used for each segment + scaled_init: whether use scaled init for weight initialization in attention layer + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + activation_fn: activation function used in FFN layer. [ReLU, GELU] supported + tanh_on_mem: whether use tanh on memory + mini_batches: use mini-btach training + negative_inf: the negative infinity value used in attention masking. default is "-inf". + For some situation, e.g. LM. it is better to use "-1e8" to avoid nan issue. + summarization_method: method to generate segment summrization embedding + max_relative_position: max relatie position for relative position embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + [TODO]: remove the rpe_old_option by the end of 2021 Q1. + + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + num_layers=1, + dropout_in_attn=0.0, + dropout_on_attn=0.0, + dropout_on_fc1=0.0, + dropout_on_fc2=0.0, + segment_size=128, + context_config=(0, 0), + max_memory_size=0, + scaled_init=True, + std_scale=None, + activation_fn="relu", + tanh_on_mem=False, + mini_batches=False, + negative_inf="-inf", + deep_init=True, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super().__init__(None) + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + # we used to support growing memory size. However, it will cause + # cross stream batching failure. Now we need to have exact max memory size + if max_memory_size < 0: + raise ValueError("max_memory_size must be >= 0") + + # Only assign right_context. In decoding, left context will be cached. + # No need to let the online decoder to re-assign the left context + self.left_context, self.right_context = context_config + self.segment_size = segment_size + self.memory_dim = input_dim + self.max_memory_size = max_memory_size + self.mini_batches = mini_batches + if self.max_memory_size != 0: + self.use_mem = True + else: + self.use_mem = False + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + self.layers = torch.nn.ModuleList() + self.num_layers = num_layers + self.max_relative_position = max_relative_position + if self.max_relative_position > 0: + self.use_rpe = True + else: + self.use_rpe = False + for i in range(self.num_layers): + if deep_init: + layer_index = i + else: + layer_index = -1 + + self.layers.append( + NoSegAugmentedMemoryTransformer( + num_heads=num_heads, + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_in_attn=dropout_in_attn, + dropout_on_attn=dropout_on_attn, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + segment_size=segment_size, + std_scale=std_scale, + activation_fn=activation_fn, + tanh_on_mem=tanh_on_mem, + scaled_init=scaled_init, + use_mem=self.use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + summarization_method=summarization_method, + max_relative_position=max_relative_position, + rpe_old_option=rpe_old_option, + ) + ) + + def set_mini_batches(self, mini_batches): + # handy function only used for unit test + self.mini_batches = mini_batches + for layer in self.layers: + layer.set_mini_batches(mini_batches) + + def _get_relative_position( + self, + input: Tensor, + max_relative_position: int, + left_context_length: int, + past_length: int, + is_decoding: bool, + ): + # For training, we copy the right context to the start of the utterance + # First dimension in distance is corresponding to query. + # [right context, utterance, summary vector] + # Second dimension in distance is corresponding to key. + # [Memory bank, right context, utterance] + # For summary vector in query part, the distance with + # all other position is 2*max_position. For memory bank in key, + # the distance with all other positions is 0. + + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + + # utterance + u_st = past_length * self.segment_size + u_ed = u_st + T + utterance_ranges = torch.arange(u_st, u_ed - self.right_context) + + # left context. Only in minibatch or decoding + left_context_ranges = torch.arange(u_st - left_context_length, u_st) + + # Right context block + # right context + utterance + right_context_blocks = [] + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + u_st + ed = st + self.right_context + assert ed < u_ed + temp = torch.arange(st, ed) + right_context_blocks.append(temp) + right_context_blocks.append(torch.arange(u_ed - self.right_context, u_ed)) + right_context_ranges = torch.cat(right_context_blocks) + + if self.use_mem: + # Memory bank + # The position for memory -n, .., -1 + if is_decoding: + memory_size = min(past_length, self.max_memory_size) + else: + memory_size = num_segs + past_length - 1 + memory_bank_ranges = torch.arange( + -max_relative_position - 1, -max_relative_position - 1 - memory_size, -1 + ) + + # summary vector + # The position for summary vector as the T+max_relative_position+1. + # After the clamping, the relative position is max_relative_position + summary_pos_st = u_ed + max_relative_position + 1 + summary_vector_ranges = torch.arange( + summary_pos_st, summary_pos_st + num_segs + ) + + key_ranges = torch.cat( + [ + memory_bank_ranges, + right_context_ranges, + left_context_ranges, + utterance_ranges, + ] + ) + + query_ranges = torch.cat( + [right_context_ranges, utterance_ranges, summary_vector_ranges] + ) + else: + key_ranges = torch.cat( + [right_context_ranges, left_context_ranges, utterance_ranges] + ) + + query_ranges = torch.cat([right_context_ranges, utterance_ranges]) + + distance = key_ranges[None, :] - query_ranges[:, None] + distance_clamp = ( + torch.clamp(distance, -max_relative_position, max_relative_position) + + max_relative_position + ) + distance_clamp = distance_clamp.to(input.device).long().detach() + return distance_clamp + + def _get_attention_mask(self, input, past_length=0, left_context_cache=0): + # attention mask for each query contains three parts: + # 1. memory part + # 2. left_context + segment + # 3. right_context_block + # so for each segment and its correspoinding right context block, + # the attention matrix is formed by 9 parts: + # [0, m, 0, 0, right_context, 0, 0, seg, 0] + # [before memory, memory, after memory, before right context, right_context, + # after right context, before seg, seg, after seg] + # + # Query is formed in the way as [right_context_blocks, utterance, summary] + # + # Note: put m and right_context before segment is convenient + # for padding_mask operation. + # Key lengths = m_length + right_context_block_length + lengths + utterance_length, batch_size, _ = input.shape + summary_length = math.ceil(utterance_length / self.segment_size) + num_segs = summary_length + rc_length = self.right_context * num_segs + rc = self.right_context + lc = self.left_context + + # using mini-batches, there is left context cache available for current + # sequence. + lcc = left_context_cache + + # max_memory_size is 0 then we don't have memory and summary + # past_length is the memory carry from previous sequence + if self.use_mem: + mem_length = num_segs - 1 + past_length + else: + mem_length = 0 + rc_mask = [] + query_mask = [] + summary_mask = [] + for j in range(0, num_segs): + ssize = min(self.segment_size, utterance_length - j * self.segment_size) + + rc_size = rc + rc_mat = [] + q_mat = [] + s_mat = [] + m_start = max(j + past_length - self.max_memory_size, 0) + + # max_memory_size is 0, then we don't use memory + if self.use_mem: + # part 0: before memory + rc_mat.append(input.new_zeros(rc_size, m_start)) + q_mat.append(input.new_zeros(ssize, m_start)) + s_mat.append(input.new_zeros(1, m_start)) + + # part 1: memory + col_1 = j + past_length - m_start + rc_mat.append(torch.ones(rc_size, col_1, device=input.device)) + q_mat.append(torch.ones(ssize, col_1, device=input.device)) + # based on D22875746, disable summary query attention + # on memeory is better for long form utterance + s_mat.append(input.new_zeros(1, col_1)) + + # part 2: after memory + col_2 = mem_length - (j + past_length) + rc_mat.append(input.new_zeros(rc_size, col_2)) + q_mat.append(input.new_zeros(ssize, col_2)) + s_mat.append(input.new_zeros(1, col_2)) + + # part 3: before right context + rc_start = j * rc + rc_mat.append(input.new_zeros(rc_size, rc_start)) + q_mat.append(input.new_zeros(ssize, rc_start)) + s_mat.append(input.new_zeros(1, rc_start)) + + # part 4: right context + rc_end = rc_start + rc + col_4 = rc + rc_mat.append(torch.ones(rc_size, col_4, device=input.device)) + q_mat.append(torch.ones(ssize, col_4, device=input.device)) + s_mat.append(torch.ones(1, col_4, device=input.device)) + + # part 5: after right context + col_5 = rc_length - rc_end + rc_mat.append(input.new_zeros(rc_size, col_5)) + q_mat.append(input.new_zeros(ssize, col_5)) + s_mat.append(input.new_zeros(1, col_5)) + + # part 6: before query segment + seg_start = max(j * self.segment_size + lcc - lc, 0) + rc_mat.append(input.new_zeros(rc_size, seg_start)) + q_mat.append(input.new_zeros(ssize, seg_start)) + s_mat.append(input.new_zeros(1, seg_start)) + + # part 7: query segment + # note: right context is put in right context block + # here we only need to consider about left context + seg_end = min((j + 1) * self.segment_size + lcc, utterance_length + lcc) + col_7 = seg_end - seg_start + rc_mat.append(torch.ones(rc_size, col_7, device=input.device)) + q_mat.append(torch.ones(ssize, col_7, device=input.device)) + s_mat.append(torch.ones(1, col_7, device=input.device)) + + # part 8: after query segment + col_8 = utterance_length + lcc - seg_end + rc_mat.append(input.new_zeros(rc_size, col_8)) + q_mat.append(input.new_zeros(ssize, col_8)) + s_mat.append(input.new_zeros(1, col_8)) + + rc_mask.append(torch.cat(rc_mat, dim=1)) + query_mask.append(torch.cat(q_mat, dim=1)) + summary_mask.append(torch.cat(s_mat, dim=1)) + + # no memory, then we don't need summary either + if self.use_mem: + attention_mask = ( + 1 + - torch.cat( + [ + torch.cat(rc_mask, dim=0), + torch.cat(query_mask, dim=0), + torch.cat(summary_mask, dim=0), + ], + dim=0, + ) + ).to(torch.bool) + else: + attention_mask = ( + 1 + - torch.cat( + [torch.cat(rc_mask, dim=0), torch.cat(query_mask, dim=0)], dim=0 + ) + ).to(torch.bool) + + return attention_mask + + @torch.jit.export + def init_state( + self, batch_size: int, device: Optional[Device] = None + ) -> List[Tensor]: + empty_memory = torch.zeros( + self.num_layers, + self.max_memory_size, + batch_size, + self.memory_dim, + device=device, + ) + left_context_key = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + left_context_val = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + + return [empty_memory, left_context_key, left_context_val, past_length] + + @torch.jit.export + def batch_state(self, states: List[List[Tensor]]) -> List[Tensor]: + if len(states) == 0: + return [] + batched_m = [] + batched_lc_key = [] + batched_lc_val = [] + batched_past_length = [] + for state in states: + if len(state) == 0: + continue + m, lc_key, lc_val, past_length = state + batched_m.append(m) + batched_lc_key.append(lc_key) + batched_lc_val.append(lc_val) + batched_past_length.append(past_length) + + if ( + (len(batched_m) == 0) + or (len(batched_lc_key) == 0) + or (len(batched_lc_val) == 0) + or (len(batched_past_length) == 0) + ): + return [ + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ] + + batched_m = torch.cat(batched_m, dim=2) + batched_lc_key = torch.cat(batched_lc_key, dim=2) + batched_lc_val = torch.cat(batched_lc_val, dim=2) + batched_past_length = torch.cat(batched_past_length, dim=1) + return [batched_m, batched_lc_key, batched_lc_val, batched_past_length] + + @torch.jit.export + def reorder_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + if len(state) == 0: + return [] + m, lc_key, lc_val, past_length = state + indices = indices.to(device=m.device) + reord_m = torch.index_select(m, 2, indices) + reord_lc_key = torch.index_select(lc_key, 2, indices) + reord_lc_val = torch.index_select(lc_val, 2, indices) + reord_past_length = torch.index_select(past_length, 1, indices) + return [reord_m, reord_lc_key, reord_lc_val, reord_past_length] + + @torch.jit.export + def reset_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + m, lc_key, lc_val, past_length = state + m = m.index_fill(dim=2, index=indices, value=0.0) + lc_key = lc_key.index_fill(dim=2, index=indices, value=0.0) + lc_val = lc_val.index_fill(dim=2, index=indices, value=0.0) + past_length = past_length.index_fill(dim=1, index=indices, value=0) + + return [m, lc_key, lc_val, past_length] + + @torch.jit.export + def state_size(self) -> int: + return 4 + + @torch.jit.export + def batch_size_in_state( + self, state: Optional[List[Tensor]], sloppy: bool = True + ) -> Optional[int]: + if state is None: + return None + return state[0].size(2) + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def _gen_right_context_padded_input(self, input): + # This function deals with input that is already + # padded with right context (e.g. minibatch training) + right_context_blocks = [] + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + ed = st + self.right_context + assert ed < T + temp = input[st:ed, :, :] + right_context_blocks.append(temp) + + # last segment right context is already available + right_context_blocks.append(input[T - self.right_context :, :, :]) + return torch.cat(right_context_blocks, dim=0) + + def _gen_segs_right_context(self, input, lengths): + segments = [] + T, B, D = input.size() + nT = T - self.right_context + + # assume input is right context padded + num_segs = math.ceil(nT / self.segment_size) + # pad zeros to the utterance to make sure each + # segment has the same right context. For the + for i in range(0, num_segs - 1): + st = i * self.segment_size + ed = min(T, st + self.segment_size + self.right_context) + temp = input[st:ed, :, :] + rest_lengths = torch.clamp( + lengths - self.segment_size, min=0, max=nT - (i + 1) * self.segment_size + ) + segments.append((temp, lengths - rest_lengths + self.right_context)) + lengths = rest_lengths + + last_seg = input[st + self.segment_size :, :, :] + segments.append((last_seg, rest_lengths + self.right_context)) + + return segments + + @torch.jit.unused + def forward( + self, input: Tensor, padding_masks: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + # Xutai: originally the second argument is lengths. + lengths = (~padding_masks).sum(dim=1).long() + # mini batch training. + if self.mini_batches: + return self.forward_mini_batches(input, lengths, state) + + # regular full sequence training. Note, assume the right context in provided + # in the input. + T, B, D = input.size() + right_context_blocks = self._gen_right_context_padded_input(input) + + # generate the relative positional embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=0, + past_length=0, + is_decoding=False, + ) + else: + rpe = None + input = input[: T - self.right_context, :, :] + + attention_mask = self._get_attention_mask(input) + + # firt layer use each segment mean as memory + # ignore the last one seg average + if self.use_mem: + mems = self.gen_summary_queries(input)[:-1, :, :] + else: + mems = torch.zeros(0, input.size(1), input.size(2), device=input.device) + mems = mems.type_as(input) + + output = input + all_outputs = [] + + for layer in self.layers: + output, mems, right_context_blocks, _, _ = layer( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=None, + left_context_key=None, + left_context_val=None, + rpe=rpe, + ) + all_outputs.append(output) + return output, padding_masks, [], all_outputs + + def forward_jit_mini_batch_init( + self, + seg: Tensor, + state: Optional[List[Tensor]] = None, + is_decoding: bool = False, + ): + # Prepare state. In whole sequence training, state is ignored. + # For minibatch training, we need to prepare state + if state is None: + state = self.init_state(batch_size=seg.size(1), device=seg.device) + if seg.dtype == torch.half: + state = [state[0].half(), state[1].half(), state[2].half(), state[3]] + + if self.use_mem: + # note input average only on seg, not on right context + # first layer use each segmetn mean as memory. the last + # one segment average is used in state + full_mems = self.gen_summary_queries(seg) + if is_decoding: + mems = full_mems[0:1, :, :] + state_mems = torch.cat([state[0][0], mems], dim=0) + else: + mems = full_mems[:-1, :, :] + state_mems = torch.cat([state[0][0], full_mems], dim=0) + else: + mems = state[0][0] + state_mems = mems + + # track processed segment number or memory number + # the same batch as the same bumber of past length + past_length = state[3][0][0].item() + past_left_context = min(past_length * self.segment_size, self.left_context) + past_length = min(self.max_memory_size, past_length) + + return state, mems, state_mems, past_length, past_left_context + + def state_update_before( + self, layer: int, state: List[Tensor], past_length: int, past_left_context: int + ): + pre_mems = state[0][layer][self.max_memory_size - past_length :, :, :] + lc_key = state[1][layer][self.left_context - past_left_context :, :, :] + lc_val = state[2][layer][self.left_context - past_left_context :, :, :] + return pre_mems, lc_key, lc_val + + def state_update_after( + self, + layer: int, + state: List[Tensor], + mems: Tensor, + next_key: Tensor, + next_val: Tensor, + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + ): + # mems is used for next layer + if layer < self.num_layers - 1: + state_mems = torch.cat([state[0][layer + 1], mems], dim=0) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + + # when mems pass to next sequence, we need the last memory. when mems + # use for the next layer, we can ignore the last memory + mems = mems[:-1, :, :] + + # note state[1][i] and state[2][i] original length equals to self.left_context + new_k = torch.cat([state[1][layer], next_key], dim=0) + new_v = torch.cat([state[2][layer], next_val], dim=0) + lc_key_list.append(new_k[-self.left_context :, :, :]) + lc_val_list.append(new_v[-self.left_context :, :, :]) + return mems_list, lc_key_list, lc_val_list, mems + + def state_update_after_loop( + self, + state: List[Tensor], + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + update_length: int, + ): + state[0] = torch.stack(mems_list, dim=0) + state[1] = torch.stack(lc_key_list, dim=0) + state[2] = torch.stack(lc_val_list, dim=0) + state[3] = state[3] + update_length + return state + + @torch.jit.unused + def forward_mini_batches( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + T, B, D = input.size() + + # input without right context + seg = input[: T - self.right_context, :, :] + + # get right context blocks + right_context_blocks = self._gen_right_context_padded_input(input) + + mems_list = [] + lc_key_list = [] + lc_val_list = [] + results = self.forward_jit_mini_batch_init(seg, state, False) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=False, + ) + else: + rpe = None + + # get attention mask based on seg (not include right context) and available + # left context + attention_mask = self._get_attention_mask(seg, past_length, past_left_context) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + all_outputs = [] + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + pre_mems, lc_key, lc_val = self.state_update_before( + i, state, past_length, past_left_context + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + all_outputs.append(output) + mems_list, lc_key_list, lc_val_list, mems = self.state_update_after( + layer=i, + state=state, + mems=mems, + next_key=next_key, + next_val=next_val, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + + i += 1 + + # update state + update_length = math.ceil((T - self.right_context) / self.segment_size) + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=update_length, + ) + + return output, lengths, state, all_outputs + + def forward_jit_test( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + This one simulate sequence encoder forward jit. This is for unit test purpose. + It is not used in training or decoding. Note, extra_right_context is set in + the model. In unit test, input = [utterance, right_context], lengths = + [utterance_length]. + args: + input: input utterance + lengths: utterance input length + state: None here. input is whole utterance + """ + # [TODO] sequence_to_segment has bug in lengths. + seg_src_tokens_lengths = self._gen_segs_right_context(input, lengths) + + seg_enc_tokens_lengths: List[Tuple[Tensor, Tensor]] = [] + state: Optional[List[Tensor]] = None + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + seg_enc_tokens, seg_enc_lengths, state = self.forward_jit( + input=seg_src_tokens, lengths=seg_src_lengths, state=state + ) + seg_enc_tokens_lengths.append((seg_enc_tokens, seg_enc_lengths)) + + enc_tokens, enc_lengths = segments_to_sequence( + segments=seg_enc_tokens_lengths, time_axis=0 + ) + + state = [] # returns trivial state + + return enc_tokens, enc_lengths, state + + @torch.jit.export + def forward_jit( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Forward helper for online decoding. + + args: + input: [seg, right_context]. We assume in online we + always padding the right context to the preset right context size. + For the last segment, we may have short segment size, but right + context size is the same as other segments + lengths: utterance input length is the utterance segment length and + right context size + state: [memory, left_context_key, left_context_val]. To improve throughput, + in addition to memory, we also cache key and value for left_context in + multihead self-attention + """ + # In online decoding, input = [segment, right_context] + # Lengths = [segment_length, right_context_length] + # so we need strip right context in output + T, B, D = input.size() + rc_str = T - self.right_context + rc_end = T + right_context_blocks = input[rc_str:rc_end, :, :] + seg = input[:rc_str, :, :] + lengths = torch.clamp(lengths - self.right_context, min=0) + mems_list = [] + lc_key_list = [] + lc_val_list = [] + + results = self.forward_jit_mini_batch_init(seg, state, True) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=True, + ) + else: + rpe = None + + # memory for first layer. + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + true_mems, lc_key, lc_val = self.state_update_before( + layer=i, + state=state, + past_length=past_length, + past_left_context=past_left_context, + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward_jit( + input=output, + lengths=lengths, + mems=true_mems, + right_context_blocks=right_context_blocks, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + # mems is used for next layer + mems_list, lc_key_list, lc_val_list, _ = self.state_update_after( + layer=i, + state=state, + mems_list=mems_list, + mems=mems, + next_key=next_key, + next_val=next_val, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + i += 1 + + # update state + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=1, + ) + + return output, lengths, state + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + torch.quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# Emformer encoder for seq2seq model +# This is a wrapper over the original emformer +# ------------------------------------------------------------------------------ +def emformer_encoder(klass): + class SpeechEncoder(klass): + def __init__(self, args): + super().__init__(args) + stride = SpeechEncoder.conv_layer_stride(args) + trf_left_context = args.segment_left_context // stride + trf_right_context = args.segment_right_context // stride + context_config = [trf_left_context, trf_right_context] + self.transformer_layers = nn.ModuleList( + [ + NoSegAugmentedMemoryTransformerEncoderLayer( + input_dim=args.encoder_embed_dim, + num_heads=args.encoder_attention_heads, + ffn_dim=args.encoder_ffn_embed_dim, + num_layers=args.encoder_layers, + dropout_in_attn=args.dropout, + dropout_on_attn=args.dropout, + dropout_on_fc1=args.dropout, + dropout_on_fc2=args.dropout, + activation_fn=args.activation_fn, + context_config=context_config, + segment_size=args.segment_length, + max_memory_size=args.max_memory_size, + scaled_init=True, # TODO: use constant for now. + tanh_on_mem=args.amtrf_tanh_on_mem, + ) + ] + ) + + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + (output, encoder_padding_masks, [], _) = encoder_out["encoder_out"][0] + + # This is because that in the original implementation + # the output didn't consider the last segment as right context. + encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] + + return { + "encoder_out": [output], + "encoder_padding_mask": [encoder_padding_masks], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @staticmethod + def conv_layer_stride(args): + # TODO: make it configurable from the args + return 4 + + SpeechEncoder.__name__ = klass.__name__ + return SpeechEncoder diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 1f556107a2..814924ec97 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -422,6 +422,15 @@ def s2t_transformer_s(args): base_architecture(args) +@register_model_architecture("s2t_transformer", "s2t_transformer_xs") +def s2t_transformer_xs(args): + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4) + args.dropout = getattr(args, "dropout", 0.3) + s2t_transformer_s(args) + + @register_model_architecture("s2t_transformer", "s2t_transformer_sp") def s2t_transformer_sp(args): args.encoder_layers = getattr(args, "encoder_layers", 16) diff --git a/fairseq/models/speech_to_text/utils.py b/fairseq/models/speech_to_text/utils.py new file mode 100644 index 0000000000..573f8537c9 --- /dev/null +++ b/fairseq/models/speech_to_text/utils.py @@ -0,0 +1,564 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import logging +from collections.abc import Iterable +from itertools import repeat +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + + +# ------------------------------------------------------------------------------ +# assert_equal() +# ------------------------------------------------------------------------------ + + +def assert_equal(value1, value2, name1=None, name2=None): + """Asserts two values are equal otherwise raise an error.""" + + str_name1 = "" if name1 is None else "{} ".format(name1) + str_name2 = "" if name2 is None else "{} ".format(name2) + if value1 != value2: + str_value1 = "{}" if name1 is None else "({})" + str_value1 = str_value1.format(value1) + str_value2 = "{}" if name2 is None else "({})" + str_value2 = str_value2.format(value2) + raise ValueError( + "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2) + ) + + +def fill_config(config, key, value): + if value is not None: + if key not in config or config[key] is None: + config[key] = value + assert_equal(value, config[key], "value", f'config["{key}"]') + + +# ------------------------------------------------------------------------------ +# check_and_return_expected() +# ------------------------------------------------------------------------------ + + +def check_and_return_expected(value, undefined_value, expected_value, name=None): + """ + Return the expected value while checking if the given value is undefined or + equal to the expected value. + """ + if (undefined_value is None and value is None) or (undefined_value == value): + return expected_value + if value != expected_value: + str_name = "" if name is None else "{} ".format(name) + str_value = "{}" if name is None else "({})" + str_value = str_value.format(value) + raise ValueError( + "Expected {}{} == {}".format(str_name, str_value, expected_value) + ) + return expected_value + + +# ------------------------------------------------------------------------------ +# get_time_axis() +# ------------------------------------------------------------------------------ + + +def get_time_axis(layout): + """ + Extract the time axis from the layout, for example for breaking sequence into + segments. + """ + if layout in ["TB", "TBD"]: + return 0 + if layout in ["BT", "BTD"]: + return 1 + if layout in ["BCTD"]: + return 2 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# get_batch_axis() +# ------------------------------------------------------------------------------ + + +def get_batch_axis(layout): + """ + Extract the batch axis from the layout + """ + if layout in ["TB", "TBD"]: + return 1 + if layout in ["BT", "BTD", "BCTD"]: + return 0 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# monotonically_increasing_and_bounded() +# ------------------------------------------------------------------------------ + + +def monotonically_increasing_and_bounded(iterable, min=None, max=None): + """ + Check if the elements in the given iterable are monotonically increasing and + bounded by upper/lower bounds. + """ + if not isinstance(iterable, Iterable): + raise TypeError( + "Expected iterable to be of type Iterable, got ({})".format( + iterable.__class__.__name__ + ) + ) + for i in range(len(iterable)): + if min is not None and iterable[i] < min: + return False + if max is not None and iterable[i] > max: + return False + if i > 0 and iterable[i] <= iterable[i - 1]: + return False + return True + + +# ------------------------------------------------------------------------------ +# to_pair() +# ------------------------------------------------------------------------------ + + +def to_pair(value, name): + """Make a pair (of type tuple) of given value.""" + if isinstance(value, Iterable): + if len(value) != 2: + raise ValueError( + "Expected `{}` to have exactly 2 elements, got: ({})".format( + name, value + ) + ) + return value + return tuple(repeat(value, 2)) + + +# ------------------------------------------------------------------------------ +# infer_conv_output_attrs() +# ------------------------------------------------------------------------------ + + +# TODO(cfyeh): figure out if we can get `output_dim` without calling the module. +def infer_conv_output_attrs( + module, input_channels, input_dim, batch_size=1, max_length=8 +): + """Get output attributes of a module with input.""" + input = torch.randn(batch_size, input_channels, max_length, input_dim) + output = module(input) + output_channels = output.shape[1] + output_dim = output.shape[-1] + return output_channels, output_dim + + +# ------------------------------------------------------------------------------ +# NoOp +# ------------------------------------------------------------------------------ + + +class NoOp(torch.nn.Module): + """ + NoOp simply passes the input as the output. + """ + + def __init__(self): + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +# ------------------------------------------------------------------------------ +# Permute: a torch.nn.Module applies permutation on the input tensor. +# ------------------------------------------------------------------------------ + + +class Permute(torch.nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, input: Tensor) -> Tensor: + return input.permute(self.dims).contiguous() + + +# ------------------------------------------------------------------------------ +# lengths_to_padding_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_padding_mask(lengths: Tensor) -> Tensor: + """Convert lengths of shape (B, ) to padding mask.""" + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange( # [0, ..., T-1] + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(batch_size, max_length) >= lengths.unsqueeze(1) + + return padding_mask + + +# ------------------------------------------------------------------------------ +# lengths_to_attention_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_attention_mask( + lengths: Tensor, + left_context: Optional[int] = None, + right_context: Optional[int] = None, +) -> Optional[Tensor]: + """ + Generate attention mask based on (lengths, left_context, right_context). + left_context is None means unlimited left context. + right_context is None means unlimited right context. + """ + + if left_context is None and right_context is None: + return None + + max_length = int(torch.max(lengths).item()) + + # For example, with `max_length` == 5, + # indices = tensor([ + # [ 0, 1, 2, 3, 4, 5], + # [-1, 0, 1, 2, 3, 4], + # [-2, -1, 0, 1, 2, 3], + # [-3, -2, -1, 0, 1, 2], + # [-4, -3, -2, -1, 0, 1], + # [-5, -4, -3, -2, -1, 0], + # ]) + + # In some cases the second torch.arange is created on cpu which causes a + # failure. Adding the device option to guard against it. + indices = torch.arange( + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(max_length, max_length) - torch.arange( + max_length, device=lengths.device + ).view( + max_length, -1 + ) + + # For example, with `max_length` == 5, + # bool_mask = tensor([ + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + bool_mask = ( + torch.tensor([True]).to(device=lengths.device).expand(max_length, max_length) + ) + + # For example, with `max_length` == 5, left_context == 2 + # left_mask = tensor([ + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [False, True, True, True, True], + # [False, False, True, True, True], + # ]) + if left_context is not None: + left_mask = indices >= -left_context + bool_mask = bool_mask & left_mask + + # For example, with `max_length` == 5, right_context == 1 + # right_mask = tensor([ + # [True, True, False, False, False], + # [True, True, True, False, False], + # [True, True, True, True, False], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + if right_context is not None: + right_mask = indices <= right_context + bool_mask = bool_mask & right_mask + + bool_mask = (~bool_mask).to(device=lengths.device) + return bool_mask + + +# ------------------------------------------------------------------------------ +# infer_output_norm() +# ------------------------------------------------------------------------------ + + +def infer_output_norm(module, output_norm=None): + """ + Infer the output norm (string and module) needed on the module gvien desired + output normalization. + """ + if output_norm == module.output_norm(): + # output_norm already matches module.output_norm(). + return (None, NoOp()) + + if output_norm is None and module.output_norm() is not None: + logger = logging.getLogger("infer_output_norm()") + logger.warning( + "trying to set output_norm ({}) ".format(output_norm) + + "but got module.output_norm() ({}), ".format(module.output_norm()) + + "the combined output_norm() will be ({})".format(module.output_norm()) + ) + return (None, NoOp()) + + if output_norm == "log_softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("log_softmax", torch.nn.LogSoftmax(dim=-1)) + + if output_norm == "softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("softmax", torch.nn.Softmax(dim=-1)) + + raise ValueError( + "output_norm ({}) not in ".format(output_norm) + + "supported list = [None, softmax, log_softmax]" + ) + + +# ------------------------------------------------------------------------------ +# infer_channels_from_layout() +# ------------------------------------------------------------------------------ + + +def infer_channels_from_layout(layout, channels): + """Extract the number of channels from the layout.""" + if layout in ("TBD", "BTD"): + if channels is not None and channels != 1: + raise ValueError( + "Expected channels ({}) to be 1 for layout = {}".format( + channels, layout + ) + ) + if channels is None: + return 1 + return channels + + +# ------------------------------------------------------------------------------ +# pad_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def pad_sequence( + sequence: Tensor, + time_axis: int, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> Tensor: + """Pad extra left/right contexts to the sequence.""" + + if extra_left_context == 0 and extra_right_context == 0: + return sequence + + tensors_to_concat = [] + + if extra_left_context: + size = (extra_left_context,) + fill_value = 0 + indices = torch.full( + size=size, + fill_value=fill_value, + dtype=torch.long, + device=sequence.device, + ) + left_padding = torch.index_select(sequence, time_axis, indices) + tensors_to_concat.append(left_padding) + + tensors_to_concat.append(sequence) + + # NOTE(cfyeh): for efficiency reason we pad 0 instead of the last frame for + # extra right contexts. + if extra_right_context: + size = list(sequence.shape) + size[time_axis] = extra_right_context + right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device) + tensors_to_concat.append(right_padding) + + padded_sequence = torch.cat(tensors_to_concat, dim=time_axis) + return padded_sequence + + +# ------------------------------------------------------------------------------ +# sequence_to_segments() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def sequence_to_segments( + sequence: Tensor, + time_axis: int, + lengths: Tensor, + segment_size: Optional[int] = None, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> List[Tuple[Tensor, Tensor]]: + """Breaks sequence into segments.""" + + sequence = pad_sequence( + sequence=sequence, + time_axis=time_axis, + extra_left_context=extra_left_context, + extra_right_context=extra_right_context, + ) + + lengths = lengths + extra_left_context + extra_right_context + + segments: List[Tuple[Tensor, Tensor]] = [] + + if segment_size is None: + segments.append((sequence, lengths)) + return segments + + offset = 0 + end = sequence.shape[time_axis] + step = segment_size + size = extra_left_context + segment_size + extra_right_context + + while offset + extra_left_context + extra_right_context < end: + clamped_size = min(size, end - offset) + segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size) + indices = torch.arange( + start=offset, + end=(offset + clamped_size), + step=1, + dtype=torch.long, + device=sequence.device, + ) + segment_tensor = torch.index_select(sequence, time_axis, indices) + segments.append((segment_tensor, segment_lengths)) + offset = offset + step + + return segments + + +# ------------------------------------------------------------------------------ +# segments_to_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def segments_to_sequence( + segments: List[Tuple[Tensor, Tensor]], time_axis: int +) -> Tuple[Tensor, Tensor]: + """Concatenate segments into a full sequence.""" + if len(segments) == 1: + return segments[0] + + tensors_to_concat: List[Tensor] = [] + lengths_to_stack: List[Tensor] = [] + + for tensor, lengths in segments: + tensors_to_concat.append(tensor) + lengths_to_stack.append(lengths) + + sequence = torch.cat(tensors_to_concat, dim=time_axis) + lengths = torch.stack(lengths_to_stack, dim=0) + lengths = torch.sum(lengths, dim=0) + + return sequence, lengths + + +def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + batch_first: whether to return a (B, T) tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = False for t < lengths[b] and True otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) > lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +# ------------------------------------------------------------------------------ +# attention suppression +# ------------------------------------------------------------------------------ + + +def attention_suppression(attention_weights: Tensor, scale: float): + # B, H, qlen, klen -> B, H, qlen, 1 + attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1) + attention_nozeros = attention_prob.to(torch.bool) + nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True) + + # For very sparse situation, we need get round about 0s + key_sum = torch.sum(attention_prob, dim=-1, keepdim=True) + + # nozeros_sum should > 1 + key_mean = key_sum / (nozeros_sum + 1e-8) + + # std calculation + dis = (attention_prob - key_mean) * (attention_prob - key_mean) + + # if attention_prob[i] < threshold, then dis_masked[i] = 0; for all i + dis_masked = torch.where( + attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size()) + ) + + key_var = torch.sum(dis_masked, dim=-1, keepdim=True) + key_var = key_var / (nozeros_sum - 1.0 + 1e-8) + key_std = torch.sqrt(key_var) + key_thread = key_mean - scale * key_std + + # if attention_prob[i] >= key_thread, then attention_prob[i] + # , otherwise "-inf" + inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach() + inf_tensor[:] = float("-inf") + attention_weights_float = torch.where( + attention_prob < key_thread, + inf_tensor, + attention_weights.float(), + ) + + return attention_weights_float.type_as(attention_weights) + + +def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value): + return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 362d9b28d6..297807c31a 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from fairseq import utils +from fairseq.distributed import fsdp_wrap from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -35,6 +36,9 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + @register_model("transformer") class TransformerModel(FairseqEncoderDecoderModel): """ @@ -190,6 +194,18 @@ def add_args(parser): help='block size of quantization noise at training time') parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, help='scalar quantization noise and scalar quantization at training time') + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + 'minimum number of params for a layer to be wrapped with FSDP() when ' + 'training with --ddp-backend=fully_sharded. Smaller values will ' + 'improve memory efficiency, but may make torch.distributed ' + 'communication less efficient due to smaller input sizes. This option ' + 'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' + '--offload-activations are passed.' + ) + ) # fmt: on @classmethod @@ -240,6 +256,13 @@ def build_model(cls, args, task): args.checkpoint_activations = True # offloading implies checkpointing encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + if not args.share_all_embeddings: + min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + # fsdp_wrap is a no-op when --ddp-backend != fully_sharded + encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) + decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) return cls(args, encoder, decoder) @classmethod @@ -325,6 +348,7 @@ class TransformerEncoder(FairseqEncoder): """ def __init__(self, args, dictionary, embed_tokens): + self.args = args super().__init__(dictionary) self.register_buffer("version", torch.Tensor([3])) @@ -382,9 +406,17 @@ def __init__(self, args, dictionary, embed_tokens): def build_encoder_layer(self, args): layer = TransformerEncoderLayer(args) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding( @@ -433,19 +465,68 @@ def forward( hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. """ + return self.forward_scriptable(src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + # account for padding while computing the representation + if encoder_padding_mask is not None: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + # B x T x C -> T x B x C x = x.transpose(0, 1) - # compute padding mask - encoder_padding_mask = src_tokens.eq(self.padding_idx) - encoder_states = [] + if return_all_hiddens: + encoder_states.append(x) + # encoder layers for layer in self.layers: - x = layer(x, encoder_padding_mask) + x = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None + ) if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) @@ -454,7 +535,7 @@ def forward( x = self.layer_norm(x) # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in - # `foward` so we use a dictionary instead. + # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. return { @@ -673,9 +754,17 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) - if getattr(args, "checkpoint_activations", False): + checkpoint = getattr(args, "checkpoint_activations", False) + if checkpoint: offload_to_cpu = getattr(args, "offload_activations", False) layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = ( + getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) + if not checkpoint else 0 + ) + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward( @@ -776,13 +865,11 @@ def extract_features_scriptable( alignment_layer = self.num_layers - 1 # embed positions - positions = ( - self.embed_positions( + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state ) - if self.embed_positions is not None - else None - ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index edf62b12b3..fca9470e5e 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -14,7 +14,9 @@ register_model, register_model_architecture, ) -from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.transformer import ( + DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder +) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from omegaconf import II @@ -126,15 +128,6 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "use learned positional embeddings in the decoder"}, ) - decoder_layerdrop: float = field( - default=0.0, metadata={"help": "LayerDrop probability for decoder"} - ) - decoder_layers_to_keep: Optional[str] = field( - default=None, - metadata={ - "help": "which layers to *keep* when pruning as a comma-separated list" - }, - ) layernorm_embedding: bool = field( default=False, metadata={"help": "add layernorm to embedding"} ) @@ -148,6 +141,17 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "move checkpointed activations to CPU after they are used."}, ) + # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) quant_noise_pq: float = field( default=0.0, metadata={"help": "iterative PQ quantization noise at training time"}, @@ -156,13 +160,27 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=8, metadata={"help": "block size of quantization noise at training time"}, ) - # TODO common var add to parent quant_noise_scalar: float = field( default=0.0, metadata={ "help": "scalar quantization noise and scalar quantization at training time" }, ) + # config for Fully Sharded Data Parallel (FSDP) training + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": ( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + ) + } + ) + # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") @@ -289,7 +307,7 @@ def base_lm_architecture(args): args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.activation_fn = getattr(args, "activation_fn", "relu") args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) @@ -394,6 +412,18 @@ def transformer_lm_gpt2_small(args): base_lm_architecture(args) +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny") +def transformer_lm_gpt2_tiny(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 64) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 64) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 1) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + @register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") def transformer_lm_gpt2_medium(args): args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) @@ -416,3 +446,84 @@ def transformer_lm_gpt2_big(args): args.attention_dropout = getattr(args, "attention_dropout", 0.1) args.activation_fn = getattr(args, "activation_fn", "gelu") base_lm_architecture(args) + + +def base_gpt3_architecture(args): + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.dropout = getattr(args, "dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small") +def transformer_lm_gpt3_small(args): + # 125M params + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium") +def transformer_lm_gpt3_medium(args): + # 350M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large") +def transformer_lm_gpt3_large(args): + # 760M params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl") +def transformer_lm_gpt3_xl(args): + # 1.3B params + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 24) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7") +def transformer_lm_gpt3_2_7(args): + # 2.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7") +def transformer_lm_gpt3_6_7(args): + # 6.7B params + args.decoder_layers = getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13") +def transformer_lm_gpt3_13(args): + # 13B params + args.decoder_layers = getattr(args, "decoder_layers", 40) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 40) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175") +def transformer_lm_gpt3_175(args): + # 175B params + args.decoder_layers = getattr(args, "decoder_layers", 96) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 12288) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 96) + base_gpt3_architecture(args) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 783ebcfe6b..6999dca2d9 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -26,7 +26,7 @@ TransposeLast, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange +from fairseq.utils import buffered_arange, index_put, is_xla_tensor EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) @@ -330,47 +330,52 @@ def build_model(cls, cfg: Wav2Vec2Config, task=None): return cls(cfg) - def apply_mask(self, x, padding_mask): + def apply_mask( + self, x, padding_mask, + mask_indices=None, mask_channel_indices=None, + ): B, T, C = x.shape if self.mask_prob > 0: - mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_prob, - self.mask_length, - self.mask_selection, - self.mask_other, - min_masks=2, - no_overlap=self.no_mask_overlap, - min_space=self.mask_min_space, - ) - mask_indices = torch.from_numpy(mask_indices).to(x.device) - x[mask_indices] = self.mask_emb + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None if self.mask_channel_prob > 0: - mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, - ) - mask_channel_indices = ( - torch.from_numpy(mask_channel_indices) - .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) - ) - x[mask_channel_indices] = 0 + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) return x, mask_indices - def sample_negatives(self, y, num): + def sample_negatives(self, y, num, padding_count=None): if self.n_negatives == 0 and self.cross_sample_negatives == 0: return y.new(0) @@ -378,8 +383,9 @@ def sample_negatives(self, y, num): bsz, tsz, fsz = y.shape y = y.view(-1, fsz) # BTC => (BxT)C + # FIXME: what happens if padding_count is specified? cross_high = tsz * bsz - high = tsz + high = tsz - (padding_count or 0) with torch.no_grad(): assert high > 1, f"{bsz,tsz,fsz}" @@ -436,14 +442,40 @@ def compute_preds(self, x, y, negatives): logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) - logits /= self.logit_temp + logits = logits / self.logit_temp - if neg_is_pos.any(): - logits[1:][neg_is_pos] = float("-inf") + if is_xla_tensor(logits) or neg_is_pos.any(): + fillval = -float(2**30) + if not hasattr(self, '_inftensor'): + self._inftensor = ( + torch.tensor(fillval).to(x.device) + if is_xla_tensor(logits) else + float("-inf") + ) + logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) return logits - def forward(self, source, padding_mask=None, mask=True, features_only=False): + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length(input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]) + + return input_lengths.to(torch.long) + + def forward( + self, source, padding_mask=None, mask=True, features_only=False, + mask_indices=None, mask_channel_indices=None, + padding_count=None, + ): if self.feature_grad_mult > 0: features = self.feature_extractor(source) @@ -460,11 +492,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): unmasked_features = features.clone() if padding_mask is not None: - extra = padding_mask.size(1) % features.size(1) - if extra > 0: - padding_mask = padding_mask[:, :-extra] - padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) - padding_mask = padding_mask.all(-1) + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() if self.post_extract_proj is not None: features = self.post_extract_proj(features) @@ -487,8 +526,14 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): features = self.project_inp(features) if mask: - x, mask_indices = self.apply_mask(features, padding_mask) - if mask_indices is not None: + x, mask_indices = self.apply_mask( + features, padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if not is_xla_tensor(x) and mask_indices is not None: + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. y = unmasked_features[mask_indices].view( unmasked_features.size(0), -1, unmasked_features.size(-1) ) @@ -515,12 +560,18 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) - negs, _ = self.sample_negatives(neg_cands, y.size(1)) + neg_cands = self.quantizer( + unmasked_features, produce_targets=False + )["x"] + negs, _ = self.sample_negatives( + neg_cands, y.size(1), padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) if self.codebook_negatives > 0: cb_negs = self.quantizer.sample_from_codebook( @@ -535,12 +586,20 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): y = self.project_q(y) if self.negatives_from_everywhere: - negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs, _ = self.sample_negatives( + unmasked_features, y.size(1), + padding_count=padding_count, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, y.size(1), padding_count=padding_count, + ) - x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + if not is_xla_tensor(x): + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) if self.target_glu: y = self.target_glu(y) @@ -549,7 +608,9 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): x = self.final_proj(x) x = self.compute_preds(x, y, negs) - result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} + result = { + "x": x, "padding_mask": padding_mask, "features_pen": features_pen, + } if prob_ppl is not None: result["prob_perplexity"] = prob_ppl @@ -737,11 +798,11 @@ def forward(self, x, padding_mask=None): def extract_features(self, x, padding_mask=None): if padding_mask is not None: - x[padding_mask] = 0 + x = index_put(x, padding_mask, 0) x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) - x += x_conv + x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index bbd2ab9ec5..e8a1d03eb2 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from dataclasses import dataclass, field from omegaconf import MISSING, II, open_dict -from typing import Any +from typing import Optional, Any from fairseq import checkpoint_utils, tasks, utils from fairseq.dataclass import FairseqDataclass @@ -127,7 +127,27 @@ class Wav2Vec2AsrConfig(FairseqDataclass): @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): - pass + mask_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + conv_feature_layers: Optional[str] = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": ( + "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + ), + }, + ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) @register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) @@ -158,7 +178,7 @@ def get_normalized_probs(self, net_output, log_probs): def get_logits(self, net_output): logits = net_output["encoder_out"] - padding = net_output["encoder_padding_mask"] + padding = net_output["padding_mask"] if padding is not None and padding.any(): padding = padding.T logits[padding][...,0] = 0 @@ -220,6 +240,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): share_decoder_input_output_embed: bool = field( default=False, metadata={"help": "share decoder input and output embeddings"} ) + autoregressive: bool = II("task.autoregressive") @register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) @@ -231,6 +252,8 @@ def __init__(self, encoder, decoder): def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" + assert cfg.autoregressive, "Please set task.autoregressive=true for seq2seq asr models" + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim): @@ -359,7 +382,7 @@ def forward(self, source, padding_mask, tbc=True, **kwargs): return { "encoder_out": x, # T x B x C - "encoder_padding_mask": padding_mask, # B x T + "encoder_padding_mask": padding_mask.transpose(0, 1), # T x B "padding_mask": padding_mask, } @@ -539,7 +562,7 @@ def extract_features( x, attn, _ = layer( x, encoder_out["encoder_out"] if encoder_out is not None else None, - encoder_out["encoder_padding_mask"] + encoder_out["padding_mask"] if encoder_out is not None else None, incremental_state, diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 03e70f4279..f9ada37bde 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -103,7 +103,7 @@ def upgrade_state_dict_named(self, state_dict, name): state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] - def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): + def forward(self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` @@ -135,6 +135,7 @@ def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): key=x, value=x, key_padding_mask=encoder_padding_mask, + need_weights=False, attn_mask=attn_mask, ) x = self.dropout_module(x) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 6e9c32f467..a7fb198779 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -226,6 +226,7 @@ def forward( last_state_only: bool = False, positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: is_tpu = tokens.device.type == "xla" @@ -268,7 +269,7 @@ def forward( inner_states.append(x) for layer in self.layers: - x, _ = layer(x, self_attn_padding_mask=padding_mask) + x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) if not last_state_only: inner_states.append(x) diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index 1a581bc010..a5366d6243 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -22,12 +22,13 @@ class OptimizerAndSchedulerConfig(FairseqDataclass): optimizer: Any = None lr_scheduler: Optional[Any] = None - lr: List[float] = II("optimization.lr") + lr: List = II("optimization.lr") + lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers @dataclass class CompositeOptimizerConfig(FairseqDataclass): - groups: Dict[str, OptimizerAndSchedulerConfig] = field( + groups: Dict[str, Any] = field( default_factory=lambda: {}, metadata={ "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " @@ -64,8 +65,12 @@ def __init__(self, cfg: CompositeOptimizerConfig, params): for group, group_params in groupped_params.items(): group_cfg = cfg.groups[group] with open_dict(group_cfg): - group_cfg.optimizer.lr = group_cfg.lr - group_cfg.lr_scheduler.lr = group_cfg.lr + if group_cfg.lr_float is not None: + group_cfg.optimizer.lr = [group_cfg.lr_float] + group_cfg.lr_scheduler.lr = [group_cfg.lr_float] + else: + group_cfg.optimizer.lr = group_cfg.lr + group_cfg.lr_scheduler.lr = group_cfg.lr self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) if group_cfg.lr_scheduler is not None: self.lr_schedulers[group] = build_lr_scheduler( diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index fad5a64ecb..5e935df1a5 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -107,6 +107,10 @@ def __init__( self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode ) + @property + def supports_flat_params(self): + return True + @torch.no_grad() def step(self, closure=None): loss = None diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py index c5da604220..43f9be37b9 100644 --- a/fairseq/optim/dynamic_loss_scaler.py +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -10,7 +10,7 @@ def __init__( init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000, - tolerance=0.05, + tolerance=0.0, threshold=None, min_loss_scale=1e-4, ): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index e0b069f172..00ea1bbb76 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -322,6 +322,10 @@ def set_lr(self, lr): def all_reduce_grads(self, module): self.fp32_optimizer.all_reduce_grads(module) + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + class _MemoryEfficientFP16OptimizerMixin(object): def __init__(self, *args, **kwargs): @@ -442,6 +446,10 @@ def zero_grad(self): else: self._multiply_factor = 1.0 + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + class MemoryEfficientFP16Optimizer( _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer @@ -461,8 +469,10 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. """ - def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): - if not optimizer.supports_memory_efficient_fp16: + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 38b57fe54c..51f58359ed 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import math -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index d9321577bb..0f87bb5d7e 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index 3d025a23ca..9d7f2eb9e5 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -5,11 +5,11 @@ from typing import Any, Dict -import torch +from fairseq.distributed import utils try: - from fairscale.optim import OSS, utils + from fairscale.optim import OSS _has_fairscale = True except ImportError: @@ -38,53 +38,14 @@ def broadcast_global_state_dict( self, state_dict: Dict[str, Any] ) -> Dict[str, Any]: """ - Broadcasts the relevant parts of a global state dict from rank 0 to - all other ranks. + Broadcasts the entire state_dict to all other ranks + each rank is responsible to load their own partition of data """ - if self.rank == 0: - - # Create template state dict for all other keys not related to sharding - template_state_dict = { - key: state_dict[key] - for key in state_dict - if key not in ("param_groups", "state") - } - template_state_dict["local_state_dict"] = True - - for dst_rank in range(self.world_size): - # Get the dst_rank's param_groups shard - send_state = { - "param_groups": state_dict["param_groups"][ - state_dict["partition"][dst_rank][0] : state_dict[ - "partition" - ][dst_rank][1] - ], - "state": state_dict["state"][dst_rank], - } - send_state.update(template_state_dict) - - if dst_rank == 0: - recv_state = send_state - else: - utils.broadcast_object( - send_state, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - else: - empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device) - for dst_rank in range(1, self.world_size): - state = utils.broadcast_object( - empty_buffer, - src_rank=0, - group=self.group, - dist_device=self._device, - ) - if dst_rank == self.rank: - recv_state = state - - return recv_state + return utils.broadcast_object( + state_dict, + src_rank=0, + group=self.group, + ) torch_optimizer = optimizer.optimizer optim_cls = type(torch_optimizer) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 117c6116fb..2574ab13f0 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -214,7 +214,7 @@ def _generate( raise Exception("expected src_tokens or source in net input") # bsz: total number of sentences in beam - # Note that src_tokens may have more than 2 dimenions (i.e. audio features) + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) bsz, src_len = src_tokens.size()[:2] beam_size = self.beam_size @@ -376,9 +376,7 @@ def _generate( self.search.set_src_lengths(src_lengths) if self.repeat_ngram_blocker is not None: - lprobs = self.repeat_ngram_blocker( - tokens, lprobs, bsz, beam_size, step - ) + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) # Shape: (batch, cand_size) cand_scores, cand_indices, cand_beams = self.search.step( diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 7c82777331..df073a1814 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,6 +5,7 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import logging import os import sys import torch @@ -12,10 +13,9 @@ from argparse import Namespace from dataclasses import dataclass, field from typing import Optional, Any -from omegaconf import MISSING +from omegaconf import MISSING, II from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, encoders -from fairseq.data.data_utils import post_process from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import GenerationConfig @@ -24,6 +24,9 @@ from ..logging import metrics +logger = logging.getLogger(__name__) + + class LabelEncoder(object): def __init__(self, dictionary): self.dictionary = dictionary @@ -58,7 +61,7 @@ class AudioPretrainingConfig(FairseqDataclass): default=None, metadata={"help": "max sample size to crop to for batching"} ) min_sample_size: Optional[int] = field( - default=None, metadata={"help": "min sample size to crop to for batching"} + default=None, metadata={"help": "min sample size to skip small examples"} ) # Options for reporting WER metrics during validation. Only applicable to @@ -87,6 +90,37 @@ class AudioPretrainingConfig(FairseqDataclass): "adds 'prev_output_tokens' to input and appends eos to target" }, ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "number of buckets" + }, + ) + precompute_mask_indices: bool = field( + default=False, + metadata={ + "help": "flag to compute mask indices in data preparation.", + }, + ) + # The following are needed to precompute mask and mask channel indices + # before model's forward. + mask_length: Optional[int] = II("model.mask_length") + mask_prob: Optional[float] = II("model.mask_prob") + mask_selection: Optional[str] = II("model.mask_selection") + mask_other: Optional[float] = II("model.mask_other") + no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") + mask_min_space: Optional[int] = II("model.mask_min_space") + mask_channel_length: Optional[int] = II("model.mask_channel_length") + mask_channel_prob: Optional[float] = II("model.mask_channel_prob") + mask_channel_selection: Optional[str] = II("model.mask_channel_selection") + mask_channel_other: Optional[float] = II("model.mask_channel_other") + no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") + mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") + + conv_feature_layers: Optional[str] = II("model.conv_feature_layers") + encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") + + tpu: bool = II("common.tpu") @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) @@ -98,16 +132,14 @@ class AudioPretrainingTask(FairseqTask): def __init__( self, cfg: AudioPretrainingConfig, - source_dictionary=None, - target_dictionary=None, ): super().__init__(cfg) - self._target_dictionary = target_dictionary - self._source_dictionary = source_dictionary if cfg.eval_wer: assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" self.blank_symbol = "" + self.state.add_factory("target_dictionary", self.load_target_dictionary) + @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): """Setup the task (e.g., load dictionaries). @@ -116,15 +148,41 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): cfg (AudioPretrainingConfig): configuration of this task """ - if cfg.labels: - dict_path = os.path.join(cfg.data, f"dict.{cfg.labels}.txt") - target_dictionary = Dictionary.load(dict_path) - else: - target_dictionary = None + return cls(cfg) - return cls(cfg, target_dictionary=target_dictionary) + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join( + self.cfg.data, f"dict.{self.cfg.labels}.txt" + ) + return Dictionary.load(dict_path) + return None + + def _get_mask_precompute_kwargs(self, cfg): + if self.cfg.precompute_mask_indices or self.cfg.tpu: + args = [ + 'mask_length', + 'mask_prob', + 'mask_selection', + 'mask_other', + 'no_mask_overlap', + 'mask_min_space', + 'mask_channel_length', + 'mask_channel_prob', + 'mask_channel_selection', + 'mask_channel_other', + 'no_mask_channel_overlap', + 'mask_channel_min_space', + 'encoder_embed_dim', + 'conv_feature_layers', + ] + return {arg: cfg[arg] for arg in args} + else: + return {} - def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + def load_dataset( + self, split: str, task_cfg: FairseqDataclass = None, **kwargs + ): data_path = self.cfg.data task_cfg = task_cfg or self.cfg @@ -136,17 +194,27 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): manifest = os.path.join(data_path, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=task_cfg.sample_rate, + sample_rate=task_cfg.get('sample_rate', self.cfg.sample_rate), max_sample_size=self.cfg.max_sample_size, - min_sample_size=self.cfg.max_sample_size, - min_length=self.cfg.min_sample_size, + min_sample_size=self.cfg.min_sample_size, pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask_indices=( + self.cfg.precompute_mask_indices or self.cfg.tpu + ), + **self._get_mask_precompute_kwargs(task_cfg), ) + if self.cfg.tpu and task_cfg['mask_channel_prob'] == 0.0: + logger.info( + "Pretraining on TPUs may suffer convergence " + "issues when training with `mask_channel_prob` value of " + "0. You may want to set this to a low value close to 0." + ) + if task_cfg.labels: label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") - labels = [] with open(label_path, "r") as f: labels = [ line for i, line in enumerate(f) @@ -166,18 +234,18 @@ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=task_cfg.autoregressive, + add_to_input=task_cfg.get('autoregressive', False), ) @property def source_dictionary(self): - return self._source_dictionary + return None @property def target_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language model.""" - return self._target_dictionary + return self.state.target_dictionary def max_positions(self): """Maximum input length supported by the encoder.""" diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 34264bdc01..375b5277b9 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -7,7 +7,7 @@ import os import warnings from argparse import Namespace -from typing import List +from typing import Any, Callable, Dict, List import torch from fairseq import metrics, search, tokenizer, utils @@ -20,10 +20,45 @@ logger = logging.getLogger(__name__) +class StatefulContainer(object): + + _state: Dict[str, Any] = dict() + _factories: Dict[str, Callable[[], Any]] = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + class FairseqTask(object): """ Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. """ @classmethod @@ -42,10 +77,16 @@ def logging_outputs_can_be_summed(criterion) -> bool: """ return criterion.logging_outputs_can_be_summed() + cfg: FairseqDataclass + datasets: Dict[str, FairseqDataset] + dataset_to_epoch_iter: Dict[FairseqDataset, Any] + state: StatefulContainer = None + def __init__(self, cfg: FairseqDataclass, **kwargs): self.cfg = cfg - self.datasets = {} - self.dataset_to_epoch_iter = {} + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() @classmethod def load_dictionary(cls, filename): @@ -514,6 +555,15 @@ def reduce_metrics(self, logging_outputs, criterion): criterion.__class__.reduce_metrics(logging_outputs) + def state_dict(self): + if self.state is not None: + return self.state.state_dict + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + if self.state is not None: + self.state.merge_state_dict(state_dict) + def max_positions(self): """Return the max input length allowed by the task.""" return None diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 4a44d967b3..a3847733a1 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -91,6 +91,8 @@ class LanguageModelingConfig(FairseqDataclass): ) data_buffer_size: int = II("dataset.data_buffer_size") tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") @register_task("language_modeling", dataclass=LanguageModelingConfig) @@ -184,7 +186,9 @@ def build_model(self, args): return model - def load_dataset(self, split, epoch=1, combine=False, **kwargs): + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> MonolingualDataset: """Load a given dataset split. Args: @@ -196,13 +200,12 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) + # each process has its own copy of the raw data (likely to be an np.memmap) dataset = data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine ) if dataset is None: - raise FileNotFoundError( - "Dataset not found: {} ({})".format(split, split_path) - ) + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") dataset = maybe_shorten_dataset( dataset, @@ -212,7 +215,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): self.args.tokens_per_sample, self.args.seed, ) - dataset = TokenBlockDataset( dataset, dataset.sizes, @@ -221,6 +223,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, ) add_eos_for_other_targets = ( @@ -228,7 +233,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): and self.args.sample_break_mode != "none" ) - self.datasets[split] = self._initialize_dataset( + self.datasets[split] = MonolingualDataset( dataset=dataset, sizes=dataset.sizes, src_vocab=self.dictionary, @@ -239,9 +244,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): add_bos_token=self.args.add_bos_token, ) - def _initialize_dataset(self, **kwargs): - return MonolingualDataset(**kwargs) - def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): """ Generate batches for inference. We prepend an eos token to src_tokens diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 90635d882f..331f685495 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -394,7 +394,11 @@ def reduce_metrics(self, logging_outputs, criterion): if self.cfg.eval_bleu: def sum_logs(key): - return sum(log.get(key, 0) for log in logging_outputs) + import torch + result = sum(log.get(key, 0) for log in logging_outputs) + if torch.is_tensor(result): + result = result.cpu() + return result counts, totals = [], [] for i in range(EVAL_BLEU_ORDER): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 49129a7fb0..6c5ac3654a 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -25,6 +25,8 @@ from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import OmegaConf + logger = logging.getLogger(__name__) @@ -61,15 +63,31 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): else: self.device = torch.device("cpu") + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + else: + if self.cfg.distributed_training.cpu_offload: + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + # copy model and criterion to current device/dtype self._criterion = criterion self._model = model - if cfg.common.fp16: - self._criterion = self._criterion.half() - self._model = self._model.half() - elif cfg.common.bf16: - self._criterion = self._criterion.to(dtype=torch.bfloat16) - self._model = self._model.to(dtype=torch.bfloat16) + if cfg.distributed_training.ddp_backend != "fully_sharded": + if cfg.common.fp16: + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) if ( not cfg.distributed_training.pipeline_model_parallel # the DistributedFairseqModel wrapper will handle moving to device, @@ -169,8 +187,27 @@ def use_distributed_wrapper(self) -> bool: return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.distributed_training.cpu_offload ) + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return True + else: + return self.is_data_parallel_master + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(self.data_parallel_rank) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" + @property def criterion(self): if self._wrapped_criterion is None: @@ -222,7 +259,20 @@ def _build_optimizer(self): ) ) - if self.cfg.common.fp16 or self.cfg.common.bf16: + if ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and self.cfg.common.fp16 + ): + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. + allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " @@ -242,6 +292,16 @@ def _build_optimizer(self): logger.info("NOTE: your device may support faster training with --fp16") self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + if self.cfg.distributed_training.ddp_backend == "fully_sharded": + assert not self.cfg.optimization.use_bmuf, \ + "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + if self.cfg.optimization.use_bmuf: self._optimizer = optim.FairseqBMUF( self.cfg.bmuf, @@ -274,23 +334,50 @@ def consolidate_optimizer(self): if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg) + if OmegaConf.is_config(self.cfg) else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + } + } + if not self.cfg.checkpoint.no_save_optimizer_state: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + return state_dict + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - if self.is_data_parallel_master: # only save one checkpoint - extra_state["metrics"] = metrics.state_dict() - extra_state["previous_training_time"] = self.cumulative_training_time() - checkpoint_utils.save_state( + logger.info(f"Saving checkpoint to {filename}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.torch_persistent_save( + state_dict, filename, - self.cfg, - self.model.state_dict(), - self.criterion, - self.optimizer, - self.lr_scheduler, - self.get_num_updates(), - self._optim_history, - extra_state, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, ) - logger.info(f"Finished saving checkpoint to {filename}") + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, @@ -316,6 +403,8 @@ def load_checkpoint( # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or self.cfg.distributed_training.ddp_backend == "fully_sharded" ) if load_on_all_ranks or self.data_parallel_rank == 0: @@ -353,10 +442,14 @@ def load_checkpoint( self.model.load_state_dict( state["model"], strict=True, model_cfg=self.cfg.model ) + # save memory for later steps + del state["model"] if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( state["criterion"], strict=True ) + del state["criterion"] + except Exception: raise Exception( "Cannot load model parameters from checkpoint {}; " @@ -373,10 +466,10 @@ def load_checkpoint( last_optim = self._optim_history[-1] assert ( last_optim["criterion_name"] == self.get_criterion().__class__.__name__ - ), "Criterion does not match; please reset the optimizer (--reset-optimizer)." + ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" assert ( last_optim["optimizer_name"] == self.optimizer.__class__.__name__ - ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)." + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) @@ -399,7 +492,7 @@ def load_checkpoint( self.lr_step(epoch) - if itr_state["version"] >= 2 and itr_state["iterations_in_epoch"] == 0: + if itr_state.get("version", 1) >= 2 and itr_state["iterations_in_epoch"] == 0: # reset meters at start of epoch reset_meters = True @@ -439,6 +532,7 @@ def get_train_iterator( epoch=epoch, combine=combine, data_selector=data_selector, + tpu=self.tpu, ) batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(self.cfg.dataset.train_subset), @@ -591,9 +685,7 @@ def maybe_no_sync(): # before marking step can lead to OOM errors. # To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass - import torch_xla.core.xla_model as xm - - xm.mark_step() + self._xla_markstep_and_send_to_cpu() if is_dummy_batch: if torch.is_tensor(sample_size): @@ -715,7 +807,6 @@ def maybe_no_sync(): if self.tpu: # mark step on TPUs import torch_xla.core.xla_model as xm - xm.mark_step() # only log stats every log_interval steps @@ -732,7 +823,7 @@ def maybe_no_sync(): metrics.log_scalar( "gb_total", gb_total, priority=1600, round=1, weight=0 ) - + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm ) @@ -785,9 +876,7 @@ def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" if self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous("valid_step") # wait for all workers - xm.mark_step() with torch.no_grad(): self.model.eval() @@ -830,6 +919,8 @@ def valid_step(self, sample, raise_oom=False): ) # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) return logging_output @@ -926,7 +1017,24 @@ def set_num_updates(self, num_updates): metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) def clip_grad_norm(self, clip_norm): - return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + + def agg_norm_fn(total_norm): + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) + return total_norm ** 0.5 + + should_agg_norm = ( + self.cfg.distributed_training.ddp_backend == "fully_sharded" + and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ) + ) + return self.optimizer.clip_grad_norm( + clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None + ) def cumulative_training_time(self): if self._cumulative_training_time is None: @@ -1059,10 +1167,7 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. fast_stat_sync_sum is @@ -1111,7 +1216,7 @@ def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( torch.isfinite(tensor).all() - or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() ) if not is_consistent(self._grad_norm_buf): @@ -1178,6 +1283,8 @@ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): return logging_output def _check_xla_compilation(self): + if self.cfg.distributed_training.distributed_rank: + return import torch_xla.debug.metrics as met compile_stats = met.metric_data("CompileTime") @@ -1186,13 +1293,18 @@ def _check_xla_compilation(self): num_xla_compiles = compile_stats[0] if num_xla_compiles > self._num_xla_compiles: logger.warning( - "XLA compilation detected on device #{}; too many of these can lead " - "to slow training, but we expect a few in the beginning".format( - self.cfg.distributed_training.distributed_rank - ) + "XLA compilation detected; too many of these can lead " + "to slow training, but we expect a few in the beginning" ) self._num_xla_compiles = num_xla_compiles + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + return xla_device_to_cpu(data) + def _catalog_shared_params(module, memo=None, prefix=""): if memo is None: diff --git a/fairseq/utils.py b/fairseq/utils.py index d4bf73648b..90bb8369f2 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -110,6 +110,7 @@ def _move_to_cuda(tensor): def move_to_cpu(sample): + def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. @@ -120,6 +121,17 @@ def _move_to_cpu(tensor): return apply_to_sample(_move_to_cpu, sample) +def move_to_tpu(sample): + + import torch_xla.core.xla_model as xm + device = xm.xla_device() + + def _move_to_tpu(tensor): + return tensor.to(device) + + return apply_to_sample(_move_to_tpu, sample) + + def get_incremental_state( module: MultiheadAttention, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], @@ -289,6 +301,9 @@ def convert_padding_direction( def item(tensor): + # tpu-comment: making this a no-op for xla devices. + if torch.is_tensor(tensor) and tensor.device.type == 'xla': + return tensor.detach() if hasattr(tensor, "item"): return tensor.item() if hasattr(tensor, "__getitem__"): @@ -679,6 +694,27 @@ def tpu_data_loader(itr): ) +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == 'xla' + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + return xm._maybe_convert_to_cpu(dat) + + class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 9af7568a77..d48da7dda5 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -18,16 +18,17 @@ import torch from fairseq import ( checkpoint_utils, - distributed_utils, options, quantization_utils, tasks, utils, ) from fairseq.data import iterators +from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.distributed_utils import is_master +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils +from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer @@ -49,7 +50,7 @@ def main(cfg: FairseqConfig) -> None: utils.import_user_module(cfg.common) - if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) @@ -67,6 +68,16 @@ def main(cfg: FairseqConfig) -> None: # Print args logger.info(cfg) + if cfg.checkpoint.write_checkpoints_asynchronously: + try: + import iopath # noqa: F401 + except ImportError: + logging.exception( + "Asynchronous checkpoint writing is specified but iopath is " + "not installed: `pip install iopath`" + ) + return + # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) @@ -76,7 +87,11 @@ def main(cfg: FairseqConfig) -> None: assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion - model = task.build_model(cfg.model) + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) @@ -84,8 +99,8 @@ def main(cfg: FairseqConfig) -> None: logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {:,} (num. trained: {:,})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()), + sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad), ) ) @@ -104,14 +119,13 @@ def main(cfg: FairseqConfig) -> None: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( "training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size ) ) logger.info( - "max tokens per GPU = {} and batch size per GPU = {}".format( + "max tokens per device = {} and max sentences per device = {}".format( cfg.dataset.max_tokens, cfg.dataset.batch_size, ) @@ -125,6 +139,9 @@ def main(cfg: FairseqConfig) -> None: # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) + if cfg.common.tpu: + import torch_xla.core.xla_model as xm + xm.rendezvous("load_checkpoint") # wait for all workers max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() @@ -157,6 +174,15 @@ def main(cfg: FairseqConfig) -> None: train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + # ioPath implementation to wait for all asynchronous file writes to complete. + if cfg.checkpoint.write_checkpoints_asynchronously: + logger.info( + "ioPath PathManager waiting for all asynchronous checkpoint " + "writes to finish." + ) + PathManager.async_close() + logger.info("ioPath PathManager finished waiting.") + def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch @@ -236,6 +262,7 @@ def train( valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() + logger.info("Start iterating over samples") for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i @@ -441,6 +468,10 @@ def cli_main( cfg = convert_namespace_to_omegaconf(args) + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}") + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -448,6 +479,9 @@ def cli_main( else: distributed_utils.call_main(cfg, main) + # if cfg.common.use_plasma_view: + # server.server.kill() + if __name__ == "__main__": cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index c69bb94142..90d7e4c6a9 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -u -# !/usr/bin/env python3 -u # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the @@ -43,6 +42,13 @@ def main(cfg: DictConfig, override_args=None): if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) @@ -91,8 +97,8 @@ def main(cfg: DictConfig, override_args=None): ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, seed=cfg.common.seed, - num_shards=cfg.distributed_training.distributed_world_size, - shard_id=cfg.distributed_training.distributed_rank, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) @@ -111,7 +117,7 @@ def main(cfg: DictConfig, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) - if cfg.distributed_training.distributed_world_size > 1: + if data_parallel_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, @@ -132,9 +138,13 @@ def cli_main(): # only override args that are explicitly given on the command line override_parser = options.get_validation_parser() - override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + override_args = options.parse_args_and_arch( + override_parser, suppress_defaults=True + ) - distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args) + distributed_utils.call_main( + convert_namespace_to_omegaconf(args), main, override_args=override_args + ) if __name__ == "__main__": diff --git a/setup.py b/setup.py index d1a976104e..3670ff3cfc 100644 --- a/setup.py +++ b/setup.py @@ -256,5 +256,5 @@ def get_files(path, relative_to="fairseq"): } do_setup(package_data) finally: - if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + if "build_ext" not in sys.argv[1:] and os.path.islink(fairseq_examples): os.unlink(fairseq_examples) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 981ffd49cd..e10cc767b8 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -22,6 +22,7 @@ preprocess_lm_data, preprocess_summarization_data, preprocess_translation_data, + create_laser_data_and_config_json, train_translation_model, ) @@ -935,6 +936,65 @@ def test_alignment(self): ) generate_main(data_dir) + def test_laser_lstm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_lstm") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_lstm", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-bidirectional", + "--encoder-hidden-size", + "512", + "--encoder-layers", + "5", + "--decoder-layers", + "1", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + + def test_laser_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_laser_transformer") as data_dir: + laser_config_file = create_laser_data_and_config_json(data_dir) + train_translation_model( + laser_config_file.name, + "laser_transformer", + [ + "--user-dir", + "examples/laser/laser_src", + "--weighting-alpha", + "0.3", + "--encoder-embed-dim", + "320", + "--decoder-embed-dim", + "320", + "--decoder-lang-embed-dim", + "32", + "--save-dir", + data_dir, + "--disable-validation", + ], + task="laser", + lang_flags=[], + ) + def test_alignment_full_context(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory("test_alignment") as data_dir: @@ -1637,8 +1697,9 @@ def test_activation_offloading_does_not_change_metrics(self): """Neither ----checkpoint-activations nor --offload-activations should change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) offload_logs = self._train(data_dir, ["--offload-activations"]) baseline_logs = self._train(data_dir, []) @@ -1660,8 +1721,9 @@ def test_activation_checkpointing_does_not_change_metrics(self): """--checkpoint-activations should not change loss""" with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: - create_dummy_data(data_dir, num_examples=20) - preprocess_translation_data(data_dir) + with self.assertLogs(): + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) ckpt_logs = self._train(data_dir, ["--checkpoint-activations"]) baseline_logs = self._train(data_dir, []) assert len(baseline_logs) == len(ckpt_logs) diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 617a5f7c84..0f28222633 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -9,8 +9,10 @@ import tempfile import unittest from io import StringIO +from unittest.mock import patch from fairseq import checkpoint_utils +from omegaconf import OmegaConf from tests.utils import ( create_dummy_data, @@ -87,6 +89,18 @@ def test_prune_state_dict(self): self.assertEqual(len(ensemble[0].encoder.layers), 2) self.assertEqual(len(ensemble[0].decoder.layers), 1) + def test_torch_persistent_save_async(self): + state_dict = {} + filename = "async_checkpoint.pt" + + with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: + with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: + checkpoint_utils.torch_persistent_save( + state_dict, filename, async_write=True + ) + mock_opena.assert_called_with(filename, "wb") + mock_save.assert_called() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9fb69a5f77..a3e3970028 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest from typing import Sequence @@ -20,6 +21,12 @@ def sample(id: int, length: int): class TestDataset(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + def test_round_robin_zip_datasets(self): long_dataset = lang_pair_dataset([10, 9, 8, 11]) short_dataset = lang_pair_dataset([11, 9]) diff --git a/tests/test_export.py b/tests/test_export.py index 87e52bd7c1..b380697b9a 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -103,6 +103,19 @@ def test_export_transformer(self): scripted = torch.jit.script(model) _test_save_and_load(scripted) + @unittest.skipIf( + torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + ) + def test_export_transformer_no_token_pos_emb(self): + task, parser = get_dummy_task_and_parser() + TransformerModel.add_args(parser) + args = parser.parse_args([]) + args.no_token_positional_embeddings = True + model = TransformerModel.build_model(args, task) + scripted = torch.jit.script(model) + _test_save_and_load(scripted) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_io.py b/tests/test_file_io.py index aef5b80d18..8ebbba4a2e 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -45,3 +45,18 @@ def test_file_io_oss(self): with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f: s = f.read() self.assertEqual(s, self._tmpfile_contents) + + def test_file_io_async(self): + # ioPath `PathManager` is initialized after the first `opena` call. + try: + from fairseq.file_io import IOPathPathManager, PathManager + + self.assertIsNone(IOPathPathManager) + _asyncfile = os.path.join(self._tmpdir, "async.txt") + f = PathManager.opena(_asyncfile, "wb") + f.close() + + from fairseq.file_io import IOPathPathManager + self.assertIsNotNone(IOPathPathManager) + finally: + self.assertTrue(PathManager.async_close()) diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py new file mode 100644 index 0000000000..5a79f4b680 --- /dev/null +++ b/tests/test_multi_corpus_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from collections import OrderedDict + +import torch +from fairseq.data import LanguagePairDataset, TokenBlockDataset +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from tests.test_train import mock_dict + + +class TestMultiCorpusDataset(unittest.TestCase): + def setUp(self): + d = mock_dict() + tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1) + tokens_ds1 = TokenBlockDataset( + tokens_1, + sizes=[tokens_1.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_1 = LanguagePairDataset( + tokens_ds1, tokens_ds1.sizes, d, shuffle=False + ) + tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1) + tokens_ds2 = TokenBlockDataset( + tokens_2, + sizes=[tokens_2.size(-1)], + block_size=1, + pad=0, + eos=1, + include_targets=False, + ) + self.dataset_2 = LanguagePairDataset( + tokens_ds2, tokens_ds2.sizes, d, shuffle=False + ) + + def _test_sample_helper( + self, + distribution, + ): + m = MultiCorpusDataset( + OrderedDict({0: self.dataset_1, 1: self.dataset_2}), + distribution=distribution, + seed=0, + sort_indices=True, + ) + m.set_epoch(1) + indices = m.ordered_indices() + count_sample_from_first_dataset = 0 + items = set() + for i in indices: + item = m[i]["source"].item() + if item % 2 == 1: + count_sample_from_first_dataset += 1 + + items.add(item) + sample_from_first_ds_percentage = ( + 1.0 * count_sample_from_first_dataset / len(indices) + ) + self.assertLess( + abs(sample_from_first_ds_percentage - distribution[0]), + 0.01, + ) + self.assertEqual( + len(items), + int(min(len(self.dataset_1), len(indices) * distribution[0]) + + min(len(self.dataset_1), len(indices) * distribution[1])) + ) + print(distribution) + + def test_multi_corpus_dataset(self): + for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]: + self._test_sample_helper(distribution=distribution) diff --git a/tests/test_plasma_utils.py b/tests/test_plasma_utils.py new file mode 100644 index 0000000000..5737530e3d --- /dev/null +++ b/tests/test_plasma_utils.py @@ -0,0 +1,127 @@ +import contextlib +import unittest +import tempfile +from io import StringIO + +import numpy as np + +from tests.test_binaries import train_language_model +from tests.utils import create_dummy_data, preprocess_lm_data + +try: + from pyarrow import plasma + from fairseq.data.plasma_utils import PlasmaView, PlasmaStore + + PYARROW_AVAILABLE = True +except ImportError: + PYARROW_AVAILABLE = False + +dummy_path = 'dummy' + + +@unittest.skipUnless(PYARROW_AVAILABLE, "") +class TestPlasmaView(unittest.TestCase): + def setUp(self) -> None: + self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201 + self.path = self.tmp_file.name + self.server = PlasmaStore.start(path=self.path) + self.client = plasma.connect(self.path, num_retries=10) + + def tearDown(self) -> None: + self.client.disconnect() + self.tmp_file.close() + self.server.kill() + + def test_two_servers_do_not_share_object_id_space(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + server_2_path = self.path + with tempfile.NamedTemporaryFile() as server_1_path: + server = PlasmaStore.start(path=server_1_path.name, nbytes=10000) + arr1 = PlasmaView( + data_server_1, dummy_path, 1, plasma_path=server_1_path.name + ) + assert len(arr1.client.list()) == 1 + assert (arr1.array == data_server_1).all() + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path) + assert (arr2.array == data_server_2).all() + assert (arr1.array == data_server_1).all() + server.kill() + + def test_hash_collision(self): + data_server_1 = np.array([0, 1]) + data_server_2 = np.array([2, 3]) + arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path) + assert len(arr1.client.list()) == 1 + assert len(arr2.client.list()) == 1 + assert (arr2.array == data_server_1).all() + # New hash key based on tuples + arr3 = PlasmaView( + data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path + ) + assert ( + len(arr2.client.list()) == 2 + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr2.client.list() + ), "No new object was created by using a novel hash key" + assert ( + arr3.object_id in arr3.client.list() + ), "No new object was created by using a novel hash key" + del arr3, arr2, arr1 + + @staticmethod + def _assert_view_equal(pv1, pv2): + np.testing.assert_array_equal(pv1.array, pv2.array) + + def test_putting_same_array_twice(self): + data = np.array([4, 4, 4]) + arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path) + assert len(self.client.list()) == 1 + arr1b = PlasmaView( + data, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + arr1c = PlasmaView( + None, dummy_path, 1, plasma_path=self.path + ) # should not change contents of store + + assert len(self.client.list()) == 1 + self._assert_view_equal(arr1, arr1b) + self._assert_view_equal(arr1, arr1c) + PlasmaView( + data, dummy_path, 2, plasma_path=self.path + ) # new object id, adds new entry + assert len(self.client.list()) == 2 + + new_client = plasma.connect(self.path) + assert len(new_client.list()) == 2 # new client can access same objects + assert isinstance(arr1.object_id, plasma.ObjectID) + del arr1b + del arr1c + + def test_plasma_store_full_raises(self): + with tempfile.NamedTemporaryFile() as new_path: + server = PlasmaStore.start(path=new_path.name, nbytes=10000) + with self.assertRaises(plasma.PlasmaStoreFull): + # 2000 floats is more than 2000 bytes + PlasmaView( + np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name + ) + server.kill() + + def test_object_id_overflow(self): + PlasmaView.get_object_id("", 2 ** 21) + + def test_training_lm_plasma(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + ["--use-plasma-view", "--plasma-path", self.path], + run_validation=True, + ) diff --git a/tests/test_token_block_dataset.py b/tests/test_token_block_dataset.py index ea315b4e67..c4d7b76dcd 100644 --- a/tests/test_token_block_dataset.py +++ b/tests/test_token_block_dataset.py @@ -74,6 +74,19 @@ def test_complete_break_mode(self): self.assertEqual(ds[1].tolist(), [5, 1, 1]) self.assertEqual(ds[2].tolist(), [6, 1]) + def test_4billion_tokens(self): + """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745""" + data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000 + ds = self._build_dataset( + data, block_size=6, pad=0, eos=1, break_mode="complete" + ) + ds[-1] # __getitem__ works + start, end = ds.slice_indices[-1] + assert end > 4294967295 # data must be sufficiently large to overflow uint32 + assert not isinstance( + end + 1, float + ) # this would also raise, since np.uint64(1) + 1 => 2.0 + if __name__ == "__main__": unittest.main() diff --git a/tests/test_train.py b/tests/test_train.py index 57daa194b2..65f4683bc6 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model): "reset_lr_scheduler": False, "finetune_from_model": finetune_from_model, "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", }, "common": { "model_parallel_size": 1, diff --git a/tests/utils.py b/tests/utils.py index 178df5763e..1bf6f8d7f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import json import os import random import sys @@ -274,6 +275,43 @@ def preprocess_summarization_data(data_dir, extra_flags=None): preprocess.main(preprocess_args) +def create_laser_data_and_config_json(data_dir): + src_langs = ["de", "fr", "ru", "tr", "zh"] + tgt_langs = ["en", "es"] + config_json = {} + config_train_json = [] + src_vocab = None + tgt_vocab = None + + for src_lang in src_langs: + for tgt_lang in tgt_langs: + langpair_folder = f"{src_lang}-{tgt_lang}" + + langpair_path = os.path.join(data_dir, langpair_folder) + os.mkdir(langpair_path) + create_dummy_data(langpair_path) + preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"]) + + src_vocab = os.path.join(langpair_path, "dict.in.txt") + tgt_vocab = os.path.join(langpair_path, "dict.out.txt") + config_train_json.append( + { + "id": 0 if tgt_lang == "en" else 1, + "src": os.path.join(langpair_path, "train.in-out.in"), + "tgt": os.path.join(langpair_path, "train.in-out.out"), + } + ) + + config_json["src_vocab"] = src_vocab + config_json["tgt_vocab"] = tgt_vocab + config_json["train"] = config_train_json + + with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file: + json.dump(config_json, config_file) + + return config_file + + def train_translation_model( data_dir, arch,