diff --git a/seq2seq/README.md b/seq2seq/README.md
index abef56d..9b494d4 100644
--- a/seq2seq/README.md
+++ b/seq2seq/README.md
@@ -1,4 +1,4 @@
-The example models in this directory require PaddlePaddle Fluid 1.7. If your installed PaddlePaddle version is lower than this requirement, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).
+The example models in this directory require PaddlePaddle 2.0. If your installed PaddlePaddle version is lower than this requirement, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).
 
 # Sequence to Sequence (Seq2Seq)
 
@@ -12,8 +12,7 @@
 ├── download.py        # script for downloading the dataset
 ├── train.py           # main training script
 ├── predict.py         # main prediction script
-├── seq2seq_attn.py    # translation model with attention
-└── seq2seq_base.py    # translation model without attention
+└── seq2seq.py         # translation model (with or without attention)
 ```
 
 ## Introduction
@@ -35,7 +34,7 @@ Sequence to Sequence (Seq2Seq) uses an encoder-decoder (Encoder-Decoder)
 git clone https://github.com/PaddlePaddle/hapi
 cd hapi
 export PYTHONPATH=$PYTHONPATH:`pwd`
-cd examples/seq2seq
+cd seq2seq
 ```
 
 ## Data
@@ -157,7 +156,7 @@ python predict.py \
 Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the predictions, as follows:
 
 ```sh
-mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
+perl mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
 ```
 
 Each model was trained 10 times; for every run, the model saved at the 10th epoch was used for prediction with beam_size=10. The results are shown below (for readability, the 10 results are sorted in ascending order):
 
@@ -165,13 +164,12 @@ mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
 ```
 > no attention
 tst2012 BLEU:
-[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4]
+
 tst2013 BLEU:
-[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44]
 
 > with attention
 tst2012 BLEU:
-[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4]
+
 tst2013 BLEU:
-[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63]
-```
+24.94
+```
diff --git a/seq2seq/args.py b/seq2seq/args.py
index 94b07cd..0575e8e 100644
--- a/seq2seq/args.py
+++ b/seq2seq/args.py
@@ -36,8 +36,8 @@ def parse_args():
     parser.add_argument(
         "--attention",
         type=eval,
-        default=False,
-        help="Whether use attention model")
+        default=True,
+        help="Whether to use attention in the model")
 
     parser.add_argument(
         "--optimizer",
@@ -56,11 +56,13 @@ def parse_args():
         type=int,
         default=1,
         help="layers number of encoder and decoder")
+
     parser.add_argument(
         "--hidden_size",
         type=int,
         default=100,
         help="hidden size of encoder and decoder")
+
     parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
     parser.add_argument("--tar_vocab_size", type=int, help="target vocab size")
 
@@ -105,6 +107,7 @@ def parse_args():
 
     parser.add_argument(
         "--infer_file", type=str, help="file name for inference")
+
     parser.add_argument(
         "--infer_output_file",
         type=str,
@@ -120,7 +123,7 @@ def parse_args():
         help='Whether using gpu [True|False]')
 
     parser.add_argument(
-        '--eager_run', type=eval, default=False, help='Whether to use dygraph')
+        '--eager_run', type=eval, default=True, help='Whether to use dygraph')
 
     parser.add_argument(
         "--enable_ce",
diff --git a/seq2seq/download.py b/seq2seq/download.py
index 6d2981f..9ccb470 100644
--- a/seq2seq/download.py
+++ b/seq2seq/download.py
@@ -47,7 +47,7 @@ def main(arguments):
         url = remote_path + '/' + filename
         tar_file = os.path.join(tar_path, filename)
         URLLIB.urlretrieve(url, tar_file)
-        print("Downloaded sucess......")
+        print("Downloaded successfully.")
 
 
 if __name__ == '__main__':
diff --git a/seq2seq/infer.sh b/seq2seq/infer.sh
new file mode 100644
index 0000000..d2a60a5
--- /dev/null
+++ b/seq2seq/infer.sh
@@ -0,0 +1,19 @@
+export CUDA_VISIBLE_DEVICES="4"
+python predict.py \
+    --attention True \
+    --src_lang en --tar_lang vi \
+    --num_layers 2 \
+
--hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --vocab_prefix data/en-vi/vocab \ + --infer_file data/en-vi/tst2013.en \ + --reload_model attention_models/10 \ + --infer_output_file infer_output.txt \ + --beam_size 10 \ + --use_gpu True \ + --eager_run True diff --git a/seq2seq/predict.py b/seq2seq/predict.py index 39ffd65..187333b 100644 --- a/seq2seq/predict.py +++ b/seq2seq/predict.py @@ -20,14 +20,11 @@ import numpy as np import paddle -import paddle.fluid as fluid -from paddle.fluid.layers.utils import flatten -from paddle.fluid.io import DataLoader +from paddle.io import DataLoader from paddle.static import InputSpec as Input from args import parse_args -from seq2seq_base import BaseInferModel -from seq2seq_attn import AttentionInferModel +from seq2seq import Seq2SeqInfer from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_infer_input @@ -50,7 +47,7 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, def do_predict(args): device = paddle.set_device("gpu" if args.use_gpu else "cpu") - fluid.enable_dygraph(device) if args.eager_run else None + paddle.enable_static() if not args.eager_run else None # define model inputs = [ @@ -84,14 +81,14 @@ def do_predict(args): num_workers=0, return_list=True) - model_maker = AttentionInferModel if args.attention else BaseInferModel model = paddle.Model( - model_maker( + Seq2SeqInfer( args.src_vocab_size, args.tar_vocab_size, args.hidden_size, args.hidden_size, args.num_layers, + args.attention, args.dropout, bos_id=bos_id, eos_id=eos_id, @@ -109,7 +106,7 @@ def do_predict(args): # TODO(guosheng): use model.predict when support variant length with io.open(args.infer_output_file, 'w', encoding='utf-8') as f: for data in data_loader(): - finished_seq = model.test_batch(inputs=flatten(data))[0] + finished_seq = model.test_batch(inputs=list(data))[0] finished_seq = finished_seq[:, :, np.newaxis] if len( finished_seq.shape) == 2 else finished_seq finished_seq = np.transpose(finished_seq, [0, 2, 1]) diff --git a/seq2seq/reader.py b/seq2seq/reader.py index afa88a8..be91bea 100644 --- a/seq2seq/reader.py +++ b/seq2seq/reader.py @@ -24,9 +24,8 @@ from functools import partial import numpy as np -import paddle.fluid as fluid -from paddle.fluid.dygraph.parallel import ParallelEnv -from paddle.fluid.io import BatchSampler, DataLoader, Dataset +import paddle +from paddle.io import BatchSampler, DataLoader, Dataset def create_data_loader(args, device, for_train=True): @@ -68,7 +67,7 @@ def create_data_loader(args, device, for_train=True): num_workers=0, return_list=True) data_loaders[i] = data_loader - return data_loaders + return data_loaders, eos_id def prepare_train_input(insts, bos_id, eos_id, pad_id): @@ -76,8 +75,7 @@ def prepare_train_input(insts, bos_id, eos_id, pad_id): [inst[0] for inst in insts], pad_id=pad_id) trg, trg_length = pad_batch_data( [inst[1] for inst in insts], pad_id=pad_id) - trg_length = trg_length - 1 - return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis] + return src, src_length, trg[:, :-1], trg[:, 1:, np.newaxis] def prepare_infer_input(insts, bos_id, eos_id, pad_id): @@ -359,9 +357,9 @@ def __init__(self, self._random.seed(seed) # for multi-devices self._distribute_mode = distribute_mode - self._nranks = ParallelEnv().nranks - self._local_rank = ParallelEnv().local_rank - self._device_id = ParallelEnv().dev_id + self._nranks = paddle.distributed.ParallelEnv().nranks + 
self._local_rank = paddle.distributed.ParallelEnv().local_rank + self._device_id = paddle.distributed.ParallelEnv().dev_id def __iter__(self): # global sort or global shuffle diff --git a/seq2seq/seq2seq.py b/seq2seq/seq2seq.py new file mode 100644 index 0000000..b8c094b --- /dev/null +++ b/seq2seq/seq2seq.py @@ -0,0 +1,309 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.nn import Layer, Linear, Dropout, Embedding, LayerList, RNN, LSTM, LSTMCell, RNNCellBase +from paddle.fluid.layers import BeamSearchDecoder, dynamic_decode +import paddle.nn.functional as F +import paddle.nn.initializer as I + + +class CrossEntropyCriterion(Layer): + def __init__(self): + super(CrossEntropyCriterion, self).__init__() + + def forward(self, predict, trg_mask, label): + cost = F.softmax_with_cross_entropy( + logits=predict, label=label, soft_label=False) + cost = paddle.squeeze(cost, axis=[2]) + masked_cost = cost * trg_mask + batch_mean_cost = paddle.reduce_mean(masked_cost, dim=[0]) + seq_cost = paddle.reduce_sum(batch_mean_cost) + return seq_cost + + +class Encoder(Layer): + def __init__(self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + pad_id=0, + dropout_prob=0., + init_scale=0.1): + super(Encoder, self).__init__() + self.embedder = Embedding( + vocab_size, + embed_dim, + pad_id, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + direction="forward", + dropout=dropout_prob if num_layers > 1 else 0., ) + + def forward(self, sequence, sequence_length): + inputs = self.embedder(sequence) + encoder_output, encoder_state = self.lstm( + inputs, sequence_length=sequence_length) + return encoder_output, encoder_state + + +class AttentionLayer(Layer): + def __init__(self, hidden_size, bias=False, init_scale=0.1): + super(AttentionLayer, self).__init__() + self.input_proj = Linear( + hidden_size, + hidden_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + bias_attr=bias) + self.output_proj = Linear( + hidden_size + hidden_size, + hidden_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + bias_attr=bias) + + def forward(self, hidden, encoder_output, encoder_padding_mask): + query = self.input_proj(hidden) + # query = hidden + # encoder_output = self.input_proj(encoder_output) + attn_scores = paddle.matmul( + paddle.unsqueeze(query, [1]), encoder_output, transpose_y=True) + if encoder_padding_mask is not None: + attn_scores = paddle.add(attn_scores, encoder_padding_mask) + attn_scores = F.softmax(attn_scores) + attn_out = paddle.squeeze( + paddle.matmul(attn_scores, encoder_output), [1]) + attn_out = paddle.concat([attn_out, hidden], 1) + attn_out = self.output_proj(attn_out) + return attn_out + + +class DecoderCell(RNNCellBase): + def __init__(self, + num_layers, + input_size, 
+ hidden_size, + dropout_prob=0., + init_scale=0.1, + attention=True): + super(DecoderCell, self).__init__() + self.attention = attention + if dropout_prob > 0: + self.dropout = Dropout(dropout_prob) + else: + self.dropout = None + + self.lstm_cells = LayerList([ + LSTMCell( + input_size=input_size + hidden_size if i == 0 else hidden_size, + hidden_size=hidden_size) for i in range(num_layers) + ]) + self.attention_layer = AttentionLayer(hidden_size) + + def forward(self, + step_input, + states, + encoder_output=None, + encoder_padding_mask=None): + lstm_states, input_feed = states + new_lstm_states = [] + step_input = paddle.concat([step_input, input_feed], 1) + for i, lstm_cell in enumerate(self.lstm_cells): + out, new_lstm_state = lstm_cell(step_input, lstm_states[i]) + if self.dropout: + out = self.dropout(out) + step_input = out + new_lstm_states.append(new_lstm_state) + if self.attention: + out = self.attention_layer(step_input, encoder_output, + encoder_padding_mask) + else: + out = step_input + return out, [new_lstm_states, out] + + +class Decoder(Layer): + def __init__(self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + pad_id=0, + attention=True, + dropout_prob=0., + init_scale=0.1): + super(Decoder, self).__init__() + self.attention = attention + self.embedder = Embedding( + vocab_size, + embed_dim, + pad_id, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale))) + self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim, + hidden_size, dropout_prob, + init_scale, attention), + is_reverse=False, + time_major=False) + self.output_layer = Linear( + hidden_size, + vocab_size, + weight_attr=paddle.ParamAttr(initializer=I.Uniform( + low=-init_scale, high=init_scale)), + bias_attr=False) + + def forward(self, + target, + decoder_initial_states, + encoder_output=None, + encoder_padding_mask=None): + inputs = self.embedder(target) + decoder_output, _ = self.lstm_attention( + inputs, + initial_states=decoder_initial_states, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask) + predict = self.output_layer(decoder_output) + return predict + + +class Seq2Seq(Layer): + def __init__(self, + src_vocab_size, + trg_vocab_size, + embed_dim, + hidden_size, + num_layers, + attention=True, + dropout_prob=0., + pad_id=0, + init_scale=0.1): + super(Seq2Seq, self).__init__() + self.attention = attention + self.hidden_size = hidden_size + self.num_layers = num_layers + self.pad_id = pad_id + self.encoder = Encoder(src_vocab_size, embed_dim, hidden_size, + num_layers, pad_id, dropout_prob, init_scale) + self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size, + num_layers, pad_id, attention, dropout_prob, + init_scale) + + def forward(self, src, src_length, trg): + # encoder + encoder_output, encoder_final_states = self.encoder(src, src_length) + encoder_final_states = [ + (encoder_final_states[0][i], encoder_final_states[1][i]) + for i in range(self.num_layers) + ] + + # decoder initial states + decoder_initial_states = [ + encoder_final_states, + self.decoder.lstm_attention.cell.get_initial_states( + batch_ref=encoder_output, shape=[self.hidden_size]) + ] + if self.attention: + # attention mask to avoid paying attention on padddings + src_mask = (src != self.pad_id).astype(paddle.get_default_dtype()) + encoder_padding_mask = (src_mask - 1.0) * 1e9 + encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1]) + # decoder with attentioon + predict = self.decoder(trg, decoder_initial_states, encoder_output, + 
encoder_padding_mask) + else: + predict = self.decoder(trg, decoder_initial_states) + + trg_mask = (trg != self.pad_id).astype(paddle.get_default_dtype()) + return predict, trg_mask + + +class Seq2SeqInfer(Seq2Seq): + def __init__(self, + src_vocab_size, + trg_vocab_size, + embed_dim, + hidden_size, + num_layers, + attention=True, + dropout_prob=0., + bos_id=0, + eos_id=1, + beam_size=4, + max_out_len=256): + args = dict(locals()) + args.pop("self") + args.pop("__class__", None) # py3 + self.bos_id = args.pop("bos_id") + self.eos_id = args.pop("eos_id") + self.beam_size = args.pop("beam_size") + self.max_out_len = args.pop("max_out_len") + self.pad_id = eos_id + + super(Seq2SeqInfer, self).__init__(**args) + self.max_out_len = max_out_len + # dynamic decoder for inference + self.decoder.lstm_attention.cell.attention = attention + self.beam_search_decoder = BeamSearchDecoder( + self.decoder.lstm_attention.cell, + start_token=bos_id, + end_token=eos_id, + beam_size=beam_size, + embedding_fn=self.decoder.embedder, + output_fn=self.decoder.output_layer) + + def forward(self, src, src_length): + # encoding + encoder_output, encoder_final_state = self.encoder(src, src_length) + # decoder initial states + encoder_final_state = [ + (encoder_final_state[0][i], encoder_final_state[1][i]) + for i in range(self.num_layers) + ] + + decoder_initial_states = [ + encoder_final_state, + self.decoder.lstm_attention.cell.get_initial_states( + batch_ref=encoder_output, shape=[self.hidden_size]) + ] + # import pdb; pdb.set_trace() + if self.attention: + # attention mask to avoid paying attention on padddings + encoder_padding_mask = ( + (src != self.pad_id).astype(paddle.get_default_dtype()) - 1.0 + ) * 1e9 + encoder_padding_mask = paddle.unsqueeze(encoder_padding_mask, [1]) + # Tile the batch dimension with beam_size + encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_output, self.beam_size) + encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_padding_mask, self.beam_size) + rs, _ = dynamic_decode( + decoder=self.beam_search_decoder, + inits=decoder_initial_states, + max_step_num=self.max_out_len, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask) + else: + rs, _ = dynamic_decode( + decoder=self.beam_search_decoder, + inits=decoder_initial_states, + max_step_num=self.max_out_len) + return rs diff --git a/seq2seq/seq2seq_attn.py b/seq2seq/seq2seq_attn.py deleted file mode 100644 index 472efcd..0000000 --- a/seq2seq/seq2seq_attn.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import paddle.fluid as fluid -import paddle.fluid.layers as layers -from paddle.fluid import ParamAttr -from paddle.fluid.initializer import UniformInitializer -from paddle.fluid.dygraph import Embedding, Linear, Layer -from paddle.fluid.layers import BeamSearchDecoder -from paddle.text import DynamicDecode, RNN, BasicLSTMCell, RNNCell - -from seq2seq_base import Encoder - - -class AttentionLayer(Layer): - def __init__(self, hidden_size, bias=False, init_scale=0.1): - super(AttentionLayer, self).__init__() - self.input_proj = Linear( - hidden_size, - hidden_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale)), - bias_attr=bias) - self.output_proj = Linear( - hidden_size + hidden_size, - hidden_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale)), - bias_attr=bias) - - def forward(self, hidden, encoder_output, encoder_padding_mask): - # query = self.input_proj(hidden) - encoder_output = self.input_proj(encoder_output) - attn_scores = layers.matmul( - layers.unsqueeze(hidden, [1]), encoder_output, transpose_y=True) - if encoder_padding_mask is not None: - attn_scores = layers.elementwise_add(attn_scores, - encoder_padding_mask) - attn_scores = layers.softmax(attn_scores) - attn_out = layers.squeeze( - layers.matmul(attn_scores, encoder_output), [1]) - attn_out = layers.concat([attn_out, hidden], 1) - attn_out = self.output_proj(attn_out) - return attn_out - - -class DecoderCell(RNNCell): - def __init__(self, - num_layers, - input_size, - hidden_size, - dropout_prob=0., - init_scale=0.1): - super(DecoderCell, self).__init__() - self.dropout_prob = dropout_prob - # use add_sublayer to add multi-layers - self.lstm_cells = [] - for i in range(num_layers): - self.lstm_cells.append( - self.add_sublayer( - "lstm_%d" % i, - BasicLSTMCell( - input_size=input_size + hidden_size - if i == 0 else hidden_size, - hidden_size=hidden_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale))))) - self.attention_layer = AttentionLayer(hidden_size) - - def forward(self, - step_input, - states, - encoder_output, - encoder_padding_mask=None): - lstm_states, input_feed = states - new_lstm_states = [] - step_input = layers.concat([step_input, input_feed], 1) - for i, lstm_cell in enumerate(self.lstm_cells): - out, new_lstm_state = lstm_cell(step_input, lstm_states[i]) - step_input = layers.dropout( - out, - self.dropout_prob, - dropout_implementation='upscale_in_train' - ) if self.dropout_prob > 0 else out - new_lstm_states.append(new_lstm_state) - out = self.attention_layer(step_input, encoder_output, - encoder_padding_mask) - return out, [new_lstm_states, out] - - -class Decoder(Layer): - def __init__(self, - vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - init_scale=0.1): - super(Decoder, self).__init__() - self.embedder = Embedding( - size=[vocab_size, embed_dim], - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale))) - self.lstm_attention = RNN(DecoderCell( - num_layers, embed_dim, hidden_size, dropout_prob, init_scale), - is_reverse=False, - time_major=False) - self.output_layer = Linear( - hidden_size, - vocab_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale)), - bias_attr=False) - - def forward(self, target, decoder_initial_states, encoder_output, - encoder_padding_mask): - inputs = self.embedder(target) - decoder_output, _ = self.lstm_attention( - inputs, - 
initial_states=decoder_initial_states, - encoder_output=encoder_output, - encoder_padding_mask=encoder_padding_mask) - predict = self.output_layer(decoder_output) - return predict - - -class AttentionModel(Layer): - def __init__(self, - src_vocab_size, - trg_vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - init_scale=0.1): - super(AttentionModel, self).__init__() - self.hidden_size = hidden_size - self.encoder = Encoder(src_vocab_size, embed_dim, hidden_size, - num_layers, dropout_prob, init_scale) - self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size, - num_layers, dropout_prob, init_scale) - - def forward(self, src, src_length, trg): - # encoder - encoder_output, encoder_final_state = self.encoder(src, src_length) - - # decoder initial states: use input_feed and the structure is - # [[h,c] * num_layers, input_feed], consistent with DecoderCell.states - decoder_initial_states = [ - encoder_final_state, - self.decoder.lstm_attention.cell.get_initial_states( - batch_ref=encoder_output, shape=[self.hidden_size]) - ] - # attention mask to avoid paying attention on padddings - src_mask = layers.sequence_mask( - src_length, - maxlen=layers.shape(src)[1], - dtype=encoder_output.dtype) - encoder_padding_mask = (src_mask - 1.0) * 1e9 - encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) - - # decoder with attentioon - predict = self.decoder(trg, decoder_initial_states, encoder_output, - encoder_padding_mask) - return predict - - -class AttentionInferModel(AttentionModel): - def __init__(self, - src_vocab_size, - trg_vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - bos_id=0, - eos_id=1, - beam_size=4, - max_out_len=256): - args = dict(locals()) - args.pop("self") - args.pop("__class__", None) # py3 - self.bos_id = args.pop("bos_id") - self.eos_id = args.pop("eos_id") - self.beam_size = args.pop("beam_size") - self.max_out_len = args.pop("max_out_len") - super(AttentionInferModel, self).__init__(**args) - # dynamic decoder for inference - decoder = BeamSearchDecoder( - self.decoder.lstm_attention.cell, - start_token=bos_id, - end_token=eos_id, - beam_size=beam_size, - embedding_fn=self.decoder.embedder, - output_fn=self.decoder.output_layer) - self.beam_search_decoder = DynamicDecode( - decoder, max_step_num=max_out_len, is_test=True) - - def forward(self, src, src_length): - # encoding - encoder_output, encoder_final_state = self.encoder(src, src_length) - - # decoder initial states - decoder_initial_states = [ - encoder_final_state, - self.decoder.lstm_attention.cell.get_initial_states( - batch_ref=encoder_output, shape=[self.hidden_size]) - ] - # attention mask to avoid paying attention on padddings - src_mask = layers.sequence_mask( - src_length, - maxlen=layers.shape(src)[1], - dtype=encoder_output.dtype) - encoder_padding_mask = (src_mask - 1.0) * 1e9 - encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) - - # Tile the batch dimension with beam_size - encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch( - encoder_output, self.beam_size) - encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch( - encoder_padding_mask, self.beam_size) - - # dynamic decoding with beam search - rs, _ = self.beam_search_decoder( - inits=decoder_initial_states, - encoder_output=encoder_output, - encoder_padding_mask=encoder_padding_mask) - return rs diff --git a/seq2seq/seq2seq_base.py b/seq2seq/seq2seq_base.py deleted file mode 100644 index 07a0018..0000000 --- a/seq2seq/seq2seq_base.py +++ /dev/null 
@@ -1,200 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.fluid as fluid -import paddle.fluid.layers as layers -from paddle.fluid import ParamAttr -from paddle.fluid.initializer import UniformInitializer -from paddle.fluid.dygraph import Embedding, Linear, Layer -from paddle.fluid.layers import BeamSearchDecoder -from paddle.text import DynamicDecode, RNN, BasicLSTMCell, RNNCell - - -class CrossEntropyCriterion(Layer): - def __init__(self): - super(CrossEntropyCriterion, self).__init__() - - def forward(self, predict, trg_length, label): - # for target padding mask - mask = layers.sequence_mask( - trg_length, maxlen=layers.shape(predict)[1], dtype=predict.dtype) - - cost = layers.softmax_with_cross_entropy( - logits=predict, label=label, soft_label=False) - masked_cost = layers.elementwise_mul(cost, mask, axis=0) - batch_mean_cost = layers.reduce_mean(masked_cost, dim=[0]) - seq_cost = layers.reduce_sum(batch_mean_cost) - return seq_cost - - -class EncoderCell(RNNCell): - def __init__(self, - num_layers, - input_size, - hidden_size, - dropout_prob=0., - init_scale=0.1): - super(EncoderCell, self).__init__() - self.dropout_prob = dropout_prob - # use add_sublayer to add multi-layers - self.lstm_cells = [] - for i in range(num_layers): - self.lstm_cells.append( - self.add_sublayer( - "lstm_%d" % i, - BasicLSTMCell( - input_size=input_size if i == 0 else hidden_size, - hidden_size=hidden_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale))))) - - def forward(self, step_input, states): - new_states = [] - for i, lstm_cell in enumerate(self.lstm_cells): - out, new_state = lstm_cell(step_input, states[i]) - step_input = layers.dropout( - out, - self.dropout_prob, - dropout_implementation='upscale_in_train' - ) if self.dropout_prob > 0 else out - new_states.append(new_state) - return step_input, new_states - - @property - def state_shape(self): - return [cell.state_shape for cell in self.lstm_cells] - - -class Encoder(Layer): - def __init__(self, - vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - init_scale=0.1): - super(Encoder, self).__init__() - self.embedder = Embedding( - size=[vocab_size, embed_dim], - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale))) - self.stack_lstm = RNN(EncoderCell(num_layers, embed_dim, hidden_size, - dropout_prob, init_scale), - is_reverse=False, - time_major=False) - - def forward(self, sequence, sequence_length): - inputs = self.embedder(sequence) - encoder_output, encoder_state = self.stack_lstm( - inputs, sequence_length=sequence_length) - return encoder_output, encoder_state - - -DecoderCell = EncoderCell - - -class Decoder(Layer): - def __init__(self, - vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - init_scale=0.1): - super(Decoder, self).__init__() - self.embedder = Embedding( - size=[vocab_size, embed_dim], - 
param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale))) - self.stack_lstm = RNN(DecoderCell(num_layers, embed_dim, hidden_size, - dropout_prob, init_scale), - is_reverse=False, - time_major=False) - self.output_layer = Linear( - hidden_size, - vocab_size, - param_attr=ParamAttr(initializer=UniformInitializer( - low=-init_scale, high=init_scale)), - bias_attr=False) - - def forward(self, target, decoder_initial_states): - inputs = self.embedder(target) - decoder_output, _ = self.stack_lstm( - inputs, initial_states=decoder_initial_states) - predict = self.output_layer(decoder_output) - return predict - - -class BaseModel(Layer): - def __init__(self, - src_vocab_size, - trg_vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - init_scale=0.1): - super(BaseModel, self).__init__() - self.hidden_size = hidden_size - self.encoder = Encoder(src_vocab_size, embed_dim, hidden_size, - num_layers, dropout_prob, init_scale) - self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size, - num_layers, dropout_prob, init_scale) - - def forward(self, src, src_length, trg): - # encoder - encoder_output, encoder_final_states = self.encoder(src, src_length) - - # decoder - predict = self.decoder(trg, encoder_final_states) - return predict - - -class BaseInferModel(BaseModel): - def __init__(self, - src_vocab_size, - trg_vocab_size, - embed_dim, - hidden_size, - num_layers, - dropout_prob=0., - bos_id=0, - eos_id=1, - beam_size=4, - max_out_len=256): - args = dict(locals()) - args.pop("self") - args.pop("__class__", None) # py3 - self.bos_id = args.pop("bos_id") - self.eos_id = args.pop("eos_id") - self.beam_size = args.pop("beam_size") - self.max_out_len = args.pop("max_out_len") - super(BaseInferModel, self).__init__(**args) - # dynamic decoder for inference - decoder = BeamSearchDecoder( - self.decoder.stack_lstm.cell, - start_token=bos_id, - end_token=eos_id, - beam_size=beam_size, - embedding_fn=self.decoder.embedder, - output_fn=self.decoder.output_layer) - self.beam_search_decoder = DynamicDecode( - decoder, max_step_num=max_out_len, is_test=True) - - def forward(self, src, src_length): - # encoding - encoder_output, encoder_final_states = self.encoder(src, src_length) - # dynamic decoding with beam search - rs, _ = self.beam_search_decoder(inits=encoder_final_states) - return rs diff --git a/seq2seq/train.py b/seq2seq/train.py index 104c2f7..b78835e 100644 --- a/seq2seq/train.py +++ b/seq2seq/train.py @@ -20,23 +20,19 @@ import numpy as np import paddle -import paddle.fluid as fluid -from paddle.fluid.io import DataLoader from paddle.static import InputSpec as Input -from seq2seq_base import BaseModel, CrossEntropyCriterion -from seq2seq_attn import AttentionModel +from seq2seq import Seq2Seq, CrossEntropyCriterion from reader import create_data_loader -from utility import PPL, TrainCallback, get_model_cls +from utility import PPL, TrainCallback def do_train(args): device = paddle.set_device("gpu" if args.use_gpu else "cpu") - fluid.enable_dygraph(device) if args.eager_run else None + paddle.enable_static() if not args.eager_run else None if args.enable_ce: - fluid.default_main_program().random_seed = 102 - fluid.default_startup_program().random_seed = 102 + paddle.manual_seed(102) # define model inputs = [ @@ -47,32 +43,26 @@ def do_train(args): Input( [None, None], "int64", name="trg_word"), ] - labels = [ - Input( - [None], "int64", name="trg_length"), - Input( - [None, None, 1], "int64", name="label"), - ] + labels = [Input([None, None, 1], 
"int64", name="label"), ] # def dataloader - train_loader, eval_loader = create_data_loader(args, device) + [train_loader, eval_loader], pad_id = create_data_loader(args, device) - model_maker = get_model_cls( - AttentionModel) if args.attention else get_model_cls(BaseModel) model = paddle.Model( - model_maker(args.src_vocab_size, args.tar_vocab_size, args.hidden_size, - args.hidden_size, args.num_layers, args.dropout), + Seq2Seq(args.src_vocab_size, args.tar_vocab_size, args.hidden_size, + args.hidden_size, args.num_layers, args.attention, + args.dropout, pad_id), inputs=inputs, labels=labels) - grad_clip = fluid.clip.GradientClipByGlobalNorm( - clip_norm=args.max_grad_norm) - optimizer = fluid.optimizer.Adam( + grad_clip = paddle.nn.ClipGradByGlobalNorm(args.max_grad_norm) + optimizer = paddle.optimizer.Adam( learning_rate=args.learning_rate, - parameter_list=model.parameters(), + parameters=model.parameters(), grad_clip=grad_clip) ppl_metric = PPL(reset_freq=100) # ppl for every 100 batches model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric) + model.fit(train_data=train_loader, eval_data=eval_loader, epochs=args.max_epoch, diff --git a/seq2seq/train.sh b/seq2seq/train.sh new file mode 100644 index 0000000..df930c1 --- /dev/null +++ b/seq2seq/train.sh @@ -0,0 +1,20 @@ +export CUDA_VISIBLE_DEVICES="4" +python train.py \ + --src_lang en --tar_lang vi \ + --attention True \ + --num_layers 2 \ + --hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --train_data_prefix data/en-vi/train \ + --eval_data_prefix data/en-vi/tst2012 \ + --test_data_prefix data/en-vi/tst2013 \ + --vocab_prefix data/en-vi/vocab \ + --use_gpu True \ + --max_epoch 15 \ + --model_path ./attention_models \ + --eager_run True diff --git a/seq2seq/utility.py b/seq2seq/utility.py index fc446ef..ad99d08 100644 --- a/seq2seq/utility.py +++ b/seq2seq/utility.py @@ -16,9 +16,8 @@ import functools import paddle -import paddle.fluid as fluid from paddle.metric import Metric -from paddle.text import BasicLSTMCell +from paddle.nn import LSTMCell class TrainCallback(paddle.callbacks.ProgBarLogger): @@ -58,7 +57,7 @@ def __init__(self, reset_freq=100, name=None): self.reset() def compute(self, pred, seq_length, label): - word_num = fluid.layers.reduce_sum(seq_length) + word_num = paddle.reduce_sum(seq_length) return word_num def update(self, word_num): @@ -79,22 +78,3 @@ def cal_acc_ppl(self, batch_loss, batch_size): self.total_loss += batch_loss * batch_size ppl = math.exp(self.total_loss / self.word_count) return ppl - - -def get_model_cls(model_cls): - """ - Patch for BasicLSTMCell to make `_forget_bias.stop_gradient=True` - Remove this workaround when BasicLSTMCell or recurrent_op is fixed. - """ - - @functools.wraps(model_cls.__init__) - def __lstm_patch__(self, *args, **kwargs): - self._raw_init(*args, **kwargs) - layers = self.sublayers(include_sublayers=True) - for layer in layers: - if isinstance(layer, BasicLSTMCell): - layer._forget_bias.stop_gradient = False - - model_cls._raw_init = model_cls.__init__ - model_cls.__init__ = __lstm_patch__ - return model_cls