eval.py
#!/usr/bin/env python
# coding: utf-8
import os
import argparse

import torch
from torch.autograd import Variable

from layer import QRNNLayer
from model import QRNNModel
from data.util import fopen
from data.util import load_inv_dict
import data.data_utils as data_utils
from data.data_utils import seq2words
from data.data_utils import prepare_batch
from data.data_iterator import TextIterator

use_cuda = torch.cuda.is_available()

config = {}
config['model_path'] = "model/model.pkl"
config['src_vocab'] = "source_v.json"
# Source and target share the same vocabulary file here
config['tgt_vocab'] = "source_v.json"
config['decode_input'] = "dev_source_seqs.txt"
config['decode_output'] = "decoded.txt"
# Decoding settings read by decode(); the values below are assumed defaults
config['batch_size'] = 32
config['max_decode_step'] = 100
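
# A minimal sketch of optional CLI overrides for the settings above, using the
# argparse import. The flag names mirror the config keys and are assumptions,
# not flags defined by the original script; the defaults keep behaviour
# unchanged when no flags are passed.
def update_config_from_args(config):
    parser = argparse.ArgumentParser(description='QRNN seq2seq greedy decoder')
    for key, value in config.items():
        parser.add_argument('--' + key, type=type(value), default=value)
    config.update(vars(parser.parse_args()))
    return config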
def load_model(config):
    if os.path.exists(config['model_path']):
        print 'Reloading model parameters..'
        checkpoint = torch.load(config['model_path'])
        model = QRNNModel(QRNNLayer, checkpoint['num_layers'], checkpoint['kernel_size'],
                          checkpoint['hidden_size'], checkpoint['emb_size'],
                          checkpoint['num_enc_symbols'], checkpoint['num_dec_symbols'])
        model.load_state_dict(checkpoint['state_dict'])
    else:
        raise ValueError('No such file: [{}]'.format(config['model_path']))
    # Merge the runtime settings into the checkpoint dict so decode() can read
    # both the training hyperparameters and the decoding options from it
    for key in config:
        checkpoint[key] = config[key]
    return model, checkpoint
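
# For reference, load_model() expects a checkpoint saved during training
# roughly as sketched below; the training script is not part of this file,
# and the bare variable names are assumptions:
#
#   torch.save({'state_dict': model.state_dict(),
#               'num_layers': num_layers, 'kernel_size': kernel_size,
#               'hidden_size': hidden_size, 'emb_size': emb_size,
#               'num_enc_symbols': num_enc_symbols,
#               'num_dec_symbols': num_dec_symbols},
#              config['model_path'])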
def decode(config):
    model, config = load_model(config)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            source_dict=config['src_vocab'],
                            batch_size=config['batch_size'], maxlen=None,
                            n_words_source=config['num_enc_symbols'],
                            shuffle_each_epoch=False,
                            sort_by_length=False)
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    if use_cuda:
        print 'Using gpu..'
        model = model.cuda()

    # Open the output file outside the try block so the finally clause
    # cannot hit an unbound fout if opening fails
    fout = fopen(config['decode_output'], 'w')
    try:
        print 'Decoding starts..'
        num_decoded = 0
        for idx, source_seq in enumerate(test_set):
            source, source_len = prepare_batch(source_seq)

            # preds_prev holds the decoder input: a start token followed by
            # the tokens predicted so far; preds holds the final predictions
            preds_prev = torch.zeros(len(source), config['max_decode_step']).long()
            preds_prev[:, 0] += data_utils.start_token
            preds = torch.zeros(len(source), config['max_decode_step']).long()

            if use_cuda:
                source = Variable(source.cuda())
                source_len = Variable(source_len.cuda())
                preds_prev = Variable(preds_prev.cuda())
                preds = preds.cuda()  # plain tensor: used only as storage
            else:
                source = Variable(source)
                source_len = Variable(source_len)
                preds_prev = Variable(preds_prev)

            states, memories = model.encode(source, source_len)

            # Greedy decoding: at step t, feed the t+1 tokens seen so far and
            # keep the argmax prediction for position t
            for t in xrange(config['max_decode_step']):
                # logits: [batch_size * (t+1), tgt_vocab_size]
                _, logits = model.decode(preds_prev[:, :t+1], states, memories)
                # outputs: [batch_size, t+1]
                outputs = torch.max(logits, dim=1)[1].view(len(source), -1)
                preds[:, t] = outputs[:, t].data
                if t < config['max_decode_step'] - 1:
                    preds_prev[:, t+1] = outputs[:, t]

            for i in xrange(len(preds)):
                fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
                fout.flush()

            num_decoded += len(preds)
            print '  {} lines decoded'.format(num_decoded)
        print 'Decoding terminated'
    except IOError:
        pass
    finally:
        fout.close()
if __name__ == '__main__':
    # config = update_config_from_args(config)  # optional CLI overrides (sketch above)
    print(config)
    decode(config)
    print('DONE')