from datasets import *
def test(model, enc_input, start_symbol):
# Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
enc_outputs, enc_self_attns = model.Encoder(enc_input)
dec_input = torch.zeros(1, tgt_len).type_as(enc_input.data)
next_symbol = start_symbol
for i in range(0, tgt_len):
dec_input[0][i] = next_symbol
dec_outputs, _, _ = model.Decoder(dec_input, enc_input, enc_outputs)
projected = model.projection(dec_outputs)
prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
next_word = prob.data[i]
next_symbol = next_word.item()
return dec_input
enc_inputs, dec_inputs, dec_outputs = make_data()
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
enc_inputs, _, _ = next(iter(loader))
model = torch.load('model.pth')
#这里没有model.eval()
predict_dec_input = test(model, enc_inputs[0].view(1, -1).cuda(), start_symbol=tgt_vocab["S"])
predict, _, _, _ = model(enc_inputs[0].view(1, -1).cuda(), predict_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print([src_idx2word[int(i)] for i in enc_inputs[0]], '->',
[idx2word[n.item()] for n in predict.squeeze()])
多次测试后发现有时输入同样的句子但是输出的结果不一样,是否可能因为这里缺少了model.eval()导致了在测试时模型正则化
多次测试后发现有时输入同样的句子但是输出的结果不一样,是否可能因为这里缺少了model.eval()导致了在测试时模型正则化