
Insertion Transformer: Flexible Sequence Generation via Insertion Operations #123


Abstract

  • present the Insertion Transformer, an iterative and partially autoregressive model for sequence generation based on insertion operations
    • can generate with an arbitrary ordering
    • supports autoregressive (one insertion at a time) and partially autoregressive (simultaneous insertions at multiple locations) generation
  • the Insertion Transformer outperforms many non-autoregressive approaches
  • it performs on par with the original Transformer while requiring only a logarithmic number of decoding steps

Details

Introduction

  • Neural sequence models are commonly based on an autoregressive left-to-right structure in both training and inference. Although successful in many areas, the autoregressive framework does not support parallel token generation or more elaborate orderings well
  • Recently, non-autoregressive models have been introduced in which all tokens are predicted simultaneously, but the performance degradation is large (Gu et al. 2018, Lee et al. 2018)
  • Semi-autoregressive models (Stern et al. 2018, Wang et al. 2018) generate multiple tokens simultaneously, but are bound to a left-to-right ordering
  • this work presents a flexible sequence generation framework based on insertion operations, which supports simultaneous generation in arbitrary orderings
    • the Insertion Transformer models both the token to insert (content c) and the insertion location (l)
      [figure from the paper: example of insertion-based sequence generation with content c and location l]
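As a toy illustration of the basic operation (a hypothetical example, not taken from the paper), each decoding step picks a content token c and a slot l and inserts c into the current canvas at l:

```python
def insert(canvas, token, slot):
    """Insert `token` at slot `slot`: slot 0 is before the first token,
    slot len(canvas) is after the last one."""
    return canvas[:slot] + [token] + canvas[slot:]

canvas = []
# An arbitrary (non left-to-right) insertion order; left-to-right is just one special case.
for token, slot in [("friends", 0), ("three", 0), ("ate", 2), ("together", 3)]:
    canvas = insert(canvas, token, slot)
print(canvas)   # ['three', 'friends', 'ate', 'together']
```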

Model Adjustments from original Transformer

  • (1) Full Decoder Self-Attention
    • remove causal self-attention in decoder and replace it with full self-attention
  • (2) Slot Representation via Concatenated Outputs
    • given n tokens on the current canvas, the Insertion Transformer models n+1 slots (one between each adjacent pair of tokens plus the two ends). Each slot is represented by concatenating the decoder outputs of the adjacent pair of positions, with special bos/eos tokens padding the boundaries
  • (3) Model Variants
    • (3-1) Content-Location Distribution
      • Model p(c, l) directly as a joint distribution over all (content, location) pairs
        [equation screenshot: joint content-location distribution]
        • where H (shape (T+1) × h) holds the slot representations from the last decoder layer and W (shape h × C) is the softmax projection layer; a single softmax covers the full vocabulary across all locations
      • Model p(c, l) = p(c | l) · p(l) as a factorized (conditional) distribution
        [equation screenshots: factorized location distribution p(l) and content distribution p(c | l)]
        • where h_l (shape h) is the l-th row of H and q (shape h) is a learnable query vector (a hedged LaTeX reconstruction of both parameterizations is given after this list)
    • (3-2) Contextualized Vocabulary Bias
      • to increase information sharing across slots, add a max-pooling over the final decoder hidden vectors as a shared bias. This is expected to provide the model with coverage information and to propagate count information about common words
        [equation screenshot: contextualized vocabulary bias]
    • (3-3) Mixture-of-Softmaxes Output Layer
      • since modeling both the vocabulary and the location is a difficult problem, a mixture-of-softmaxes output layer (Yang et al. 2018) is included to increase expressiveness
  • Ablation experiments on the architectural variants show that Contextual + Mixture gives the best performance, but the gap disappears once the EOS penalty is used
    [table screenshot: ablation over architectural variants]
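As a reading aid, here is a hedged LaTeX reconstruction of the two content-location parameterizations from the definitions in the bullets above (H is the (T+1) × h matrix of slot representations, W the h × C softmax projection, h_l the l-th row of H, and q a learnable query vector); the notation in the paper itself may differ:

```latex
% Joint parameterization: a single softmax over all (content, location) pairs
p(c, l \mid x, \hat{y}_t) = \mathrm{softmax}\big(\mathrm{flatten}(HW)\big),
\qquad H \in \mathbb{R}^{(T+1) \times h},\; W \in \mathbb{R}^{h \times C}

% Factorized parameterization: p(c, l) = p(c \mid l)\, p(l)
p(l \mid x, \hat{y}_t) = \mathrm{softmax}(Hq), \qquad
p(c \mid l, x, \hat{y}_t) = \mathrm{softmax}(h_l W)
```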

Training

Balanced Binary Tree

  • use a soft binary tree loss that encourages the model to assign high probability to tokens near the middle of each span, trained on randomly generated partial canvas hypotheses
    [equation screenshot: distance/weight definition for the soft binary tree loss]
# Partial Hypothesis Generation Routine
(1) randomly sample k ~ Uniform(0, |y|)
(2) shuffle the index list and keep k tokens (in their original order) as the partial canvas
(3) for each span of not-yet-produced tokens between the kept tokens, compute the distance measure (Eq. 10)
(4) define the slot loss as a weighted sum of the token negative log-likelihoods, with w_i being the softmax weights over the distances

[equation screenshot: slot loss as a weighted sum of token log-likelihoods]

  • full loss is average of slot losses across all locations
    [equation screenshot: full loss averaged over all slot locations]
  • taking the temperature parameter τ → 0 gives a peaked (parallel binary-tree) weighting, while τ → ∞ gives a uniform (any-order) weighting; a small sketch of this weighting follows
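A minimal sketch of the routine and the balanced-binary-tree weighting described above, under the assumption that the distance in Eq. 10 is the offset of a token from the centre of its span and that the weights are a softmax over the negative distances scaled by a temperature τ; all function and variable names here are mine, not the paper's:

```python
import math
import random

def sample_partial_canvas(y):
    """Partial hypothesis generation: sample k ~ Uniform(0, |y|) and keep k random
    target tokens, in their original order, as the partial canvas."""
    k = random.randint(0, len(y))
    kept = sorted(random.sample(range(len(y)), k))
    return [y[i] for i in kept]

def slot_loss(span_token_logprobs, tau=1.0):
    """Soft binary-tree loss for one slot.

    span_token_logprobs: log p(token_i | slot) for each not-yet-produced token in the
    span, ordered left to right. Tokens near the centre of the span get the largest
    weight, so the model is pushed to insert middle tokens first; tau -> 0 gives a
    peaked (parallel binary-tree) weighting, tau -> inf a uniform (any-order) one.
    """
    n = len(span_token_logprobs)
    centre = (n - 1) / 2.0
    dists = [abs(i - centre) for i in range(n)]            # distance from span centre
    z = sum(math.exp(-d / tau) for d in dists)
    weights = [math.exp(-d / tau) / z for d in dists]      # softmax over -d / tau
    return -sum(w * lp for w, lp in zip(weights, span_token_logprobs))
```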

Termination Condition

  • (1) Slot Finalization
    • for all empty spans (slots where no more tokens should be generated), the target is a single end-of-slot token
    • decoding stops once every slot predicts the end-of-slot token
  • (2) Sequence Finalization
    • slot losses for empty spans are left undefined and excluded from the overall loss
    • only once the entire sequence has been produced and every location is an empty span is the slot loss at every location taken to be the negative log-likelihood of the eos token
    • this is identical to Slot Finalization at the very end, but differs during generation, since no training signal is provided for empty slots
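A small sketch contrasting the two termination targets described above, where each slot's remaining target tokens are given as a list (an empty list means an empty span); the token names and the `None`-means-excluded convention are mine, not the paper's:

```python
END_OF_SLOT = "<end-of-slot>"   # hypothetical marker for a finished slot
EOS = "<eos>"                   # hypothetical end-of-sequence token

def slot_finalization_targets(spans):
    """Slot Finalization: every empty span is supervised with the end-of-slot token,
    so finished slots always receive an explicit 'stop inserting here' signal."""
    return [tokens if tokens else [END_OF_SLOT] for tokens in spans]

def sequence_finalization_targets(spans):
    """Sequence Finalization: empty spans are excluded from the loss (None) until the
    whole sequence has been produced; only then does every slot get eos as its target."""
    if all(not tokens for tokens in spans):        # entire sequence already produced
        return [[EOS] for _ in spans]
    return [tokens if tokens else None for tokens in spans]
```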

Training Differences

  • unlike the original Transformer, where all generation steps can be computed in a single training pass and previous state computations can be cached at inference time, the Insertion Transformer can only compute the loss for one generation step at a time under the same memory constraints, so the effective batch size is reduced. Under the right training conditions, performance is recovered.

Inference

Greedy Decoding

  • greedy decoding (one insertion per step) is supported for both the slot finalization and sequence finalization training schemes, by taking the argmax (content, location) pair at each step

Parallel Decoding

  • when trained with slot finalization, parallel inference is possible
  • a token is inserted into every unfinished slot at every step, so a sequence of length n can theoretically be generated in as few as ⌊log₂(n)⌋ + 1 steps (see the sketch below)
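A minimal sketch of parallel greedy decoding under slot finalization, assuming a hypothetical `model(src, canvas)` that returns the argmax token for each of the len(canvas)+1 slots (or an end-of-slot marker for finished slots); this illustrates the procedure, not the authors' implementation:

```python
END_OF_SLOT = "<end-of-slot>"   # hypothetical marker for a finished slot

def parallel_decode(model, src, max_steps=64):
    """Insert a token into every unfinished slot at every step.

    Because every unfinished slot is filled in parallel, the canvas can roughly double
    in length each iteration, so a length-n output needs about log2(n) + 1 steps.
    """
    canvas = []
    for _ in range(max_steps):
        slot_preds = model(src, canvas)              # one predicted token per slot
        if all(tok == END_OF_SLOT for tok in slot_preds):
            break                                    # every slot finished -> stop decoding
        new_canvas = []
        for i, tok in enumerate(slot_preds):         # slot i sits just before canvas[i]
            if tok != END_OF_SLOT:
                new_canvas.append(tok)
            if i < len(canvas):
                new_canvas.append(canvas[i])
        canvas = new_canvas
    return canvas
```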

Experiments

  • Dataset : WMT14 En2De, newstest2013 as dev and newstest2014 as test
  • transformer_base setup trained for up to 1M steps on 8 × P100 GPUs
    [table screenshot: WMT14 En-De dev results for the model variants]
  • EOS penalty: the end-of-slot token is only selected if its log-probability exceeds that of the best alternative by at least β (i.e., unless the model is really confident about eos, do not produce eos); this counters the end-of-slot token being very frequent as a training target (a small sketch follows this list)
  • applying the EOS penalty, using knowledge-distilled data as the training target, and using parallel decoding improves performance on the dev set
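A small sketch of the EOS penalty as described in the bullet above: for a given slot, the end-of-slot token is selected only when its log-probability beats the best alternative by at least β (the names and dictionary interface here are mine):

```python
def pick_token(logprobs, beta=2.0, end_of_slot="<end-of-slot>"):
    """Greedy choice for one slot with an EOS penalty.

    logprobs: dict mapping candidate tokens to log-probabilities for this slot.
    The end-of-slot token is only chosen when it beats the best ordinary token by a
    margin of at least beta, countering how frequently it appears as a training target.
    """
    best = max((t for t in logprobs if t != end_of_slot), key=logprobs.get)
    if end_of_slot in logprobs and logprobs[end_of_slot] - logprobs[best] >= beta:
        return end_of_slot
    return best
```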

Parallel Decoding

  • parallel decoding on some of the stronger models leads to performance comparable to greedy decoding, i.e., faster decoding without loss of quality
    [table screenshot: parallel vs. greedy decoding results]
  • for parallel decoding, the number of decoding iterations grows closer to logarithmically than linearly with output length
    [plot screenshot: decoding iterations vs. output length]

Test Result

  • WMT14 newstest2014 results show the Insertion Transformer performing on par with the original autoregressive Transformer while using a logarithmic number of decoding steps
    [table screenshot: WMT14 newstest2014 test results]

Examples of Decoding

[figure screenshot: example decodings from the Insertion Transformer]

Personal Thoughts

  • is there a significant speed improvement in wall-clock time? the logarithmic reduction in decoding steps may be countered by the fact that decoder states cannot be cached
  • the balanced binary tree is a soft inductive bias that makes arbitrary ordering not too difficult for the model to learn; the purely abstract ordering learned in Gu et al. 2019 seems too difficult a task for the model
  • a great idea, and the strong implementation details that lead to performance comparable to the autoregressive Transformer are impressive!
  • just read the paper, it's well-written

Link : https://arxiv.org/pdf/1902.03249.pdf
Authors : Stern et al. 2019
