diff --git a/mapping_for_corenlp.txt b/mapping_for_corenlp.txt
new file mode 100644
index 00000000..fae943d4
--- /dev/null
+++ b/mapping_for_corenlp.txt
@@ -0,0 +1,2 @@
+./raw_data_covid/small_test_tgt/covid.raw_src
+./raw_data_covid/small_test_tgt/covid.raw_tgt
diff --git a/raw_data/.gitignore b/raw_data/.gitignore
deleted file mode 100755
index c96a04f0..00000000
--- a/raw_data/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-*
-!.gitignore
\ No newline at end of file
diff --git a/raw_data/temp.raw_src b/raw_data/temp.raw_src
new file mode 100644
index 00000000..c6d0094c
--- /dev/null
+++ b/raw_data/temp.raw_src
@@ -0,0 +1,2 @@
+Terry Jones had a love of the absurd that contributed much to the anarchic humour of Monty Python's Flying Circus. His style of visual comedy, leavened with a touch of the surreal, inspired many comedians who followed him. It was on Python that he honed his directing skills, notably on Life of Brian and The Meaning of Life. A keen historian, he wrote a number of books and fronted TV documentaries on ancient and medieval history. Terence Graham Parry Jones was born in Colwyn Bay in north Wales on 1 February 1942. His grandparents ran the local amateur operatic society and staged Gilbert and Sullivan concerts on the town's pier each year. His family moved to Surrey when he was four but he always felt nostalgic about his native land. "I couldn't bear it and for the longest time I wanted Wales back," he once said. "I still feel very Welsh and feel it's where I should be really." After leaving the Royal Grammar School in Guildford, where he captained the school, he went on to read English at St Edmund Hall, Oxford. However, as he put it, he "strayed into history", the subject in which he graduated. While at Oxford he wrote sketches for the Oxford Revue and performed alongside a fellow student, Michael Palin.
+(CNN) An Iranian chess referee says she is frightened to return home after she was criticized online for not wearing the appropriate headscarf during an international tournament. Currently the chief adjudicator at the Women's World Chess Championship held in Russia and China, Shohreh Bayat says she fears arrest after a photograph of her was taken during the event and was then circulated online in Iran. "They are very sensitive about the hijab when we are representing Iran in international events and even sometimes they send a person with the team to control our hijab," Bayat told CNN Sport in a phone interview Tuesday. The headscarf, or the hijab, has been a mandatory part of women's dress in Iran since the 1979 Islamic revolution but, in recent years, some women have mounted opposition and staged protests about headwear rules. Bayat said she had been wearing a headscarf at the tournament but that certain camera angles had made it look like she was not. "If I come back to Iran, I think there are a few possibilities. It is highly possible that they arrest me [...] or it is possible that they invalidate my passport," added Bayat. "I think they want to make an example of me." The photographs were taken at the first stage of the chess championship in Shanghai, China, but Bayat has since flown to Vladivostok, Russia, for the second leg between Ju Wenjun and Aleksandra Goryachkina. She was left "panicked and shocked" when she became aware of the reaction in Iran after checking her phone in the hotel room. The 32-year-old said she felt helpless as websites reportedly condemned her for what some described as protesting the country's compulsory law.
Subsequently, Bayat has decided to no longer wear the headscarf. "I'm not wearing it anymore because what is the point? I was just tolerating it, I don't believe in the hijab," she added. "People must be free to choose to wear what they want, and I was only wearing the hijab because I live in Iran and I had to wear it. I had no other choice." Bayat says she sought help from the country's chess federation. She says the federation told her to post an apology on her social media channels. She agreed under the condition that the federation would guarantee her safety but she said they refused. "My husband is in Iran, my parents are in Iran, all my family members are in Iran. I don't have anyone else outside of Iran. I don't know what to say, this is a very hard situation," she said. CNN contacted the Iranian Chess Federation on Tuesday but has yet to receive a response.
\ No newline at end of file
diff --git a/runs/cnn-abs.sh b/runs/cnn-abs.sh
new file mode 100644
index 00000000..2807453f
--- /dev/null
+++ b/runs/cnn-abs.sh
@@ -0,0 +1,114 @@
+# STEP 2
+# There's a bug in the databuilder where the file referenced is not created and populated.
+# Here's the workaround:
+# python command you're _supposed to be able to run:
+# python src/preprocess.py \
+    # --mode tokenize \
+    # --raw_path ../raw_data_1 \
+    # --save_path ../results \
+    # --log_file ../logs/cnndm.log
+
+# Java command you can _actually_ run:
+java edu.stanford.nlp.pipeline.StanfordCoreNLP \
+    -annotators tokenize,ssplit \
+    -ssplit.newlineIsSentenceBreak always \
+    -filelist mapping_for_corenlp.txt \
+    -outputFormat json \
+    -outputDirectory ./results
+# note: mapping_for_corenlp.txt is actually a file you need to make that should contain one input_file per line
+# output in results directory
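The java invocation above assumes the CoreNLP classes are already on the Java classpath; the jar path below is an assumption, point it at your own install. mapping_for_corenlp.txt can then be built with one shell line from the two paths this repo's copy of the file lists:

    export CLASSPATH=/path/to/stanford-corenlp.jar
    printf '%s\n' ./raw_data_covid/small_test_tgt/covid.raw_src ./raw_data_covid/small_test_tgt/covid.raw_tgt > mapping_for_corenlp.txt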
+
+
+# STEP 3
+python src/preprocess.py \
+    --mode format_to_lines \
+    --raw_path results \
+    --save_path json_data \
+    --n_cpus 1 \
+    --use_bert_basic_tokenizer false \
+    --map_path urls \
+    --log_file logs/format_to_lines.log
+
+# Output files will now be in the json directory
+
+
+# STEP 4
+python src/preprocess.py \
+    --mode format_to_bert \
+    --raw_path ./json_data \
+    --save_path ./bert_data \
+    --lower \
+    --n_cpus 1 \
+    --log_file ./logs/preprocess.log
+
+# Output in bert_data
+
+
+# STEP 5. Model Training
+# --visible_gpus 0,1,2 \ # for multiple gpus
+# --visible_gpus 0 \ # for a single gpu
+python src/train.py \
+    --task abs \
+    --mode train \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 4 \
+    --batch_size 3000 \
+    --train_steps 5 \
+    --accum_count 2 \
+    --log_file ./logs/abs_bert_cnndm \
+    --use_interval true \
+    --warmup_steps 1 \
+    --max_pos 512 \
+    --model_path ./models \
+    --bert_data_path ./bert_data/cnndm_sample
+
+# outputs to models (example): model_step_4.pt
+
+# all in one attempt mentioned in the jan 22 update
+# --test_from PreSumm/models/model_step_49.pt \
+
+python src/train.py \
+    --task abs \
+    --mode test_text \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 99 \
+    --batch_size 3000 \
+    --accum_count 2 \
+    --log_file logs/ext_bert \
+    --use_interval true \
+    --warmup_steps 100 \
+    --max_pos 512 \
+    --train_steps 100 \
+    --visible_gpus 0 \
+    --model_path models/ \
+    --result_path results \
+    --bert_data_path bert_data_covid \
+    --text_src raw_data_covid/small_test_tgt/covid.raw_src \
+    --text_tgt raw_data_covid/small_test_tgt/covid.raw_tgt \
+    --test_from models/model_step_148000.pt
+
+python src/train.py \
+    --task abs \
+    --mode train \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 99 \
+    --batch_size 3000 \
+    --accum_count 2 \
+    --log_file logs/ext_bert \
+    --use_interval true \
+    --warmup_steps 100 \
+    --max_pos 512 \
+    --train_steps 100 \
+    --visible_gpus 0 \
+    --model_path models/ \
+    --result_path results \
+    --bert_data_path bert_data_covid/ \
+    --text_src raw_data_covid/small_test_tgt/covid.raw_src \
+    --text_tgt raw_data_covid/small_test_tgt/covid.raw_tgt \
+    --test_from models/model_step_148000.pt
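Format note for --text_src/--text_tgt: load_text in src/models/data_loader.py reads both files line by line, so each line of covid.raw_src should hold one whole document and the matching line of covid.raw_tgt its reference summary, the same one-document-per-line layout as raw_data/temp.raw_src above.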
diff --git a/runs/cnn-ext.sh b/runs/cnn-ext.sh
new file mode 100644
index 00000000..7d448cd9
--- /dev/null
+++ b/runs/cnn-ext.sh
@@ -0,0 +1,115 @@
+# STEP 2
+# There's a bug in the databuilder where the file referenced is not created and populated.
+# Here's the workaround:
+# python command you're _supposed to be able to run:
+# python src/preprocess.py \
+    # --mode tokenize \
+    # --raw_path ../raw_data_1 \
+    # --save_path ../results \
+    # --log_file ../logs/cnndm.log
+
+# Java command you can _actually_ run:
+java edu.stanford.nlp.pipeline.StanfordCoreNLP \
+    -annotators tokenize,ssplit \
+    -ssplit.newlineIsSentenceBreak always \
+    -filelist mapping_for_corenlp.txt \
+    -outputFormat json \
+    -outputDirectory ./results
+# note: mapping_for_corenlp.txt is actually a file you need to make that should contain one input_file per line
+# output in results directory
+
+
+# STEP 3
+python src/preprocess.py \
+    --mode format_to_lines \
+    --raw_path results \
+    --save_path json_data \
+    --n_cpus 1 \
+    --use_bert_basic_tokenizer false \
+    --map_path urls \
+    --log_file logs/cnndm.log
+
+# Output files will now be in the json directory
+
+
+# STEP 4
+python src/preprocess.py \
+    --mode format_to_bert \
+    --raw_path ./json_data \
+    --save_path ./bert_data \
+    --lower \
+    --n_cpus 1 \
+    --log_file ./logs/preprocess.log
+
+# Output in bert_data
+
+
+# STEP 5. Model Training
+# --visible_gpus 0,1,2 \ # for multiple gpus
+# --visible_gpus 0 \ # for a single gpu
+python src/train.py \
+    --task ext \
+    --mode train \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 4 \
+    --batch_size 3000 \
+    --train_steps 5 \
+    --accum_count 2 \
+    --log_file ./logs/ext_bert_cnndm \
+    --use_interval true \
+    --warmup_steps 1 \
+    --max_pos 512 \
+    --model_path ./models \
+    --bert_data_path ./bert_data/cnndm_sample
+
+
+
+# all in one attempt mentioned in the jan 22 update
+# --test_from PreSumm/models/model_step_49.pt \
+
+python src/train.py \
+    --task ext \
+    --mode test_text \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 99 \
+    --batch_size 3000 \
+    --accum_count 2 \
+    --log_file logs/ext_bert_cnndm \
+    --use_interval true \
+    --warmup_steps 100 \
+    --max_pos 512 \
+    --train_steps 100 \
+    --visible_gpus 0 \
+    --model_path models/ \
+    --result_path results \
+    --bert_data_path bert_data/cnndm_sample \
+    --text_src raw_data/temp.raw_src \
+    --text_tgt raw_data/temp.raw_tgt \
+    --test_from models/bertext_cnndm_transformer.pt
+
+
+python src/train.py \
+    --task abs \
+    --mode test_text \
+    --ext_dropout 0.1 \
+    --lr .002 \
+    --report_every 50 \
+    --save_checkpoint_steps 99 \
+    --batch_size 3000 \
+    --accum_count 2 \
+    --log_file logs/ext_bert \
+    --use_interval true \
+    --warmup_steps 100 \
+    --max_pos 512 \
+    --train_steps 100 \
+    --visible_gpus 0 \
+    --model_path models/ \
+    --result_path results \
+    --bert_data_path bert_data_covid \
+    --text_src raw_data_covid/covid.raw_src \
+    --text_tgt raw_data_covid/covid.raw_tgt \
+    --test_from models/model_step_148000.pt
diff --git a/src/models/data_loader.py b/src/models/data_loader.py
index 6f3e38fe..63feca7f 100644
--- a/src/models/data_loader.py
+++ b/src/models/data_loader.py
@@ -4,6 +4,7 @@
 import random
 
 import torch
+from tqdm import tqdm
 
 from others.logging import logger
 
@@ -30,13 +31,13 @@ def __init__(self, data=None, device=None, is_test=False):
             tgt = torch.tensor(self._pad(pre_tgt, 0))
             segs = torch.tensor(self._pad(pre_segs, 0))
-            mask_src = 1 - (src == 0)
-            mask_tgt = 1 - (tgt == 0)
+            mask_src = 1 - (src == 0).float()
+            mask_tgt = 1 - (tgt == 0).float()
 
             clss = torch.tensor(self._pad(pre_clss, -1))
             src_sent_labels = torch.tensor(self._pad(pre_src_sent_labels, 0))
-            mask_cls = 1 - (clss == -1)
+            mask_cls = 1 - (clss == -1).float()
             clss[clss == -1] = 0
             setattr(self, 'clss', clss.to(device))
             setattr(self, 'mask_cls', mask_cls.to(device))
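Why the .float() casts in the Batch hunk above: on PyTorch 1.2 and later a comparison like src == 0 returns a torch.bool tensor, and subtracting a bool tensor raises a RuntimeError, so the masks are built from a float copy instead. A minimal sketch of the failure and the fix, assuming nothing beyond torch:

    import torch

    src = torch.tensor([[5, 7, 0, 0]])   # 0 is the pad id
    pad = (src == 0)                     # dtype torch.bool on PyTorch >= 1.2
    # mask_src = 1 - pad                 # RuntimeError: "-" is not supported on bool tensors
    mask_src = 1 - pad.float()           # tensor([[1., 1., 0., 0.]]): 1 on tokens, 0 on padding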
@@ -94,7 +95,7 @@ def _lazy_dataset_loader(pt_file, corpus_type):
         yield _lazy_dataset_loader(pt, corpus_type)
 
 
-def abs_batch_size_fn(new, count):
+def abs_batch_size_fn(new, count, max_ndocs_in_batch=6):
     src, tgt = new[0], new[1]
     global max_n_sents, max_n_tokens, max_size
     if count == 1:
@@ -104,12 +105,13 @@
     max_n_sents = max(max_n_sents, len(tgt))
     max_size = max(max_size, max_n_sents)
     src_elements = count * max_size
-    if (count > 6):
+    if (count > max_ndocs_in_batch):
         return src_elements + 1e3
     return src_elements
 
-def ext_batch_size_fn(new, count):
+
+def ext_batch_size_fn(new, count, max_ndocs_in_batch=6):
     if (len(new) == 4):
         pass
     src, labels = new[0], new[4]
@@ -185,11 +187,6 @@ def data(self):
             xs = self.dataset
         return xs
 
-
-
-
-
-
     def preprocess(self, ex, is_test):
         src = ex['src']
         tgt = ex['tgt'][:self.args.max_tgt_len][:-1]+[2]
@@ -209,8 +206,6 @@
         clss = clss[:max_sent_id]
         # src_txt = src_txt[:max_sent_id]
 
-
-
         if(is_test):
             return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt
         else:
@@ -225,13 +220,13 @@ def batch_buffer(self, data, batch_size):
             if(ex is None):
                 continue
             minibatch.append(ex)
-            size_so_far = self.batch_size_fn(ex, len(minibatch))
+            size_so_far = self.batch_size_fn(ex, len(minibatch), self.args.max_ndocs_in_batch)
             if size_so_far == batch_size:
                 yield minibatch
                 minibatch, size_so_far = [], 0
             elif size_so_far > batch_size:
                 yield minibatch[:-1]
-                minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1)
+                minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, len(minibatch), self.args.max_ndocs_in_batch)
         if minibatch:
             yield minibatch
 
@@ -240,13 +235,13 @@ def batch(self, data, batch_size):
         minibatch, size_so_far = [], 0
         for ex in data:
             minibatch.append(ex)
-            size_so_far = self.batch_size_fn(ex, len(minibatch))
+            size_so_far = self.batch_size_fn(ex, len(minibatch), self.args.max_ndocs_in_batch)
             if size_so_far == batch_size:
                 yield minibatch
                 minibatch, size_so_far = [], 0
             elif size_so_far > batch_size:
                 yield minibatch[:-1]
-                minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1)
+                minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, len(minibatch), self.args.max_ndocs_in_batch)
         if minibatch:
             yield minibatch
 
@@ -287,93 +282,77 @@ def __iter__(self):
         return
 
 
-class TextDataloader(object):
-    def __init__(self, args, datasets, batch_size,
-                 device, shuffle, is_test):
-        self.args = args
-        self.batch_size = batch_size
-        self.device = device
-
-    def data(self):
-        if self.shuffle:
-            random.shuffle(self.dataset)
-        xs = self.dataset
-        return xs
-
-    def preprocess(self, ex, is_test):
-        src = ex['src']
-        tgt = ex['tgt'][:self.args.max_tgt_len][:-1] + [2]
-        src_sent_labels = ex['src_sent_labels']
-        segs = ex['segs']
-        if (not self.args.use_interval):
-            segs = [0] * len(segs)
-        clss = ex['clss']
-        src_txt = ex['src_txt']
-        tgt_txt = ex['tgt_txt']
-
-        end_id = [src[-1]]
-        src = src[:-1][:self.args.max_pos - 1] + end_id
-        segs = segs[:self.args.max_pos]
-        max_sent_id = bisect.bisect_left(clss, self.args.max_pos)
-        src_sent_labels = src_sent_labels[:max_sent_id]
-        clss = clss[:max_sent_id]
-        # src_txt = src_txt[:max_sent_id]
-
-        if (is_test):
-            return src, tgt, segs, clss, src_sent_labels, src_txt, tgt_txt
-        else:
-            return src, tgt, segs, clss, src_sent_labels
-
-    def batch_buffer(self, data, batch_size):
-        minibatch, size_so_far = [], 0
-        for ex in data:
-            if (len(ex['src']) == 0):
-                continue
-            ex = self.preprocess(ex, self.is_test)
-            if (ex is None):
-                continue
-            minibatch.append(ex)
-            size_so_far = simple_batch_size_fn(ex, len(minibatch))
-            if size_so_far == batch_size:
-                yield minibatch
-                minibatch, size_so_far = [], 0
-            elif size_so_far > batch_size:
-                yield minibatch[:-1]
-                minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
-        if minibatch:
-            yield minibatch
 
-    def create_batches(self):
-        """ Create batches """
-        data = self.data()
-        for buffer in self.batch_buffer(data, self.batch_size * 300):
-            if (self.args.task == 'abs'):
-                p_batch = sorted(buffer, key=lambda x: len(x[2]))
-                p_batch = sorted(p_batch, key=lambda x: len(x[1]))
+def load_text(args, source_fp, target_fp, device):
+    from others.tokenization import BertTokenizer
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
+    sep_vid = tokenizer.vocab['[SEP]']
+    cls_vid = tokenizer.vocab['[CLS]']
+    n_lines = len(open(source_fp).read().split('\n'))
+
+    def _process_src(raw):
+        raw = raw.strip().lower()
+        raw = raw.replace('[cls]','[CLS]').replace('[sep]','[SEP]')
+        src_subtokens = tokenizer.tokenize(raw)
+        src_subtokens = [token.replace('##.', '[SEP]') for token in src_subtokens]
+        src_subtokens = ['[CLS]'] + src_subtokens + ['[SEP]']
+        src_subtoken_idxs = tokenizer.convert_tokens_to_ids(src_subtokens)
+        src_subtoken_idxs = src_subtoken_idxs[:-1][:args.max_pos]
+        src_subtoken_idxs[-1] = sep_vid
+        _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == sep_vid]
+        segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
+        segments_ids = []
+        segs = segs[:args.max_pos]
+        for i, s in enumerate(segs):
+            if (i % 2 == 0):
+                segments_ids += s * [0]
+            else:
-                p_batch = sorted(buffer, key=lambda x: len(x[2]))
-                p_batch = batch(p_batch, self.batch_size)
-
-            p_batch = batch(p_batch, self.batch_size)
-
-            p_batch = list(p_batch)
-            if (self.shuffle):
-                random.shuffle(p_batch)
-            for b in p_batch:
-                if (len(b) == 0):
-                    continue
-                yield b
-
-    def __iter__(self):
-        while True:
-            self.batches = self.create_batches()
-            for idx, minibatch in enumerate(self.batches):
-                # fast-forward if loaded from state
-                if self._iterations_this_epoch > idx:
-                    continue
-                self.iterations += 1
-                self._iterations_this_epoch += 1
-                batch = Batch(minibatch, self.device, self.is_test)
-
+                segments_ids += s * [1]
+
+        src = torch.tensor(src_subtoken_idxs)[None, :].to(device)
+        mask_src = (1 - (src == 0).float()).to(device)
+        cls_ids = [[i for i, t in enumerate(src_subtoken_idxs) if t == cls_vid]]
+        clss = torch.tensor(cls_ids).to(device)
+        mask_cls = 1 - (clss == -1).float()
+        clss[clss == -1] = 0
+
+        return src, mask_src, segments_ids, clss, mask_cls
+
+    if(target_fp==''):
+        with open(source_fp) as source:
+            for x in tqdm(source, total=n_lines):
+                src, mask_src, segments_ids, clss, mask_cls = _process_src(x)
+                segs = torch.tensor(segments_ids)[None, :].to(device)
+                batch = Batch()
+                batch.src = src
+                batch.tgt = None
+                batch.mask_src = mask_src
+                batch.mask_tgt = None
+                batch.segs = segs
+                batch.src_str = [[sent.replace('[SEP]','').strip() for sent in x.split('[CLS]')]]
+                batch.tgt_str = ['']
+                batch.clss = clss
+                batch.mask_cls = mask_cls
+
+                batch.batch_size=1
+                yield batch
+    else:
+        with open(source_fp) as source, open(target_fp) as target:
+            for x, y in tqdm(zip(source, target), total=n_lines):
+                x = x.strip()
+                y = y.strip()
+                y = ' '.join(y.split())
+                src, mask_src, segments_ids, clss, mask_cls = _process_src(x)
+                segs = torch.tensor(segments_ids)[None, :].to(device)
+                batch = Batch()
+                batch.src = src
+                batch.tgt = None
+                batch.mask_src = mask_src
+                batch.mask_tgt = None
+                batch.segs = segs
+                batch.src_str = [[sent.replace('[SEP]','').strip() for sent in x.split('[CLS]')]]
+                batch.tgt_str = [y]
+                batch.clss = clss
+                batch.mask_cls = mask_cls
+                batch.batch_size=1
                 yield batch
-        return
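load_text above is what replaces the deleted TextDataloader for test_text mode: it turns raw text files into one-example Batch objects. A minimal driver sketch, assuming it is run from src/ and that args carries the one field load_text reads (max_pos); the print is illustrative:

    import torch
    from argparse import Namespace
    from models.data_loader import load_text

    args = Namespace(max_pos=512)
    device = torch.device('cpu')
    # an empty target path yields batches with an empty tgt_str, as in test_text without --text_tgt
    for batch in load_text(args, '../raw_data/temp.raw_src', '', device):
        print(batch.src.shape, batch.src_str[0][0][:60])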
diff --git a/src/models/decoder.py b/src/models/decoder.py
index 9e2371bb..eae6fb65 100644
--- a/src/models/decoder.py
+++ b/src/models/decoder.py
@@ -59,9 +59,10 @@ def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
             * all_input `[batch_size x current_step x model_dim]`
 
         """
-        dec_mask = torch.gt(tgt_pad_mask +
+        dec_mask = torch.gt(tgt_pad_mask.byte() +
                             self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)], 0)
+
         input_norm = self.layer_norm_1(inputs)
         all_input = input_norm
         if previous_input is not None:
diff --git a/src/models/model_builder.py b/src/models/model_builder.py
index 6b420e49..6568a8e3 100644
--- a/src/models/model_builder.py
+++ b/src/models/model_builder.py
@@ -133,7 +133,7 @@ def forward(self, x, segs, mask):
 
 class ExtSummarizer(nn.Module):
-    def __init__(self, args, device, checkpoint):
+    def __init__(self, args, device, checkpoint, max_position_upper_lim=512):
         super(ExtSummarizer, self).__init__()
         self.args = args
         self.device = device
@@ -147,10 +147,12 @@ def __init__(self, args, device, checkpoint):
             self.bert.model = BertModel(bert_config)
 
         self.ext_layer = Classifier(self.bert.model.config.hidden_size)
-        if(args.max_pos>512):
+        if(args.max_pos>max_position_upper_lim):
             my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
-            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
-            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-512,1)
+            my_pos_embeddings.weight.data[:max_position_upper_lim] =\
+                self.bert.model.embeddings.position_embeddings.weight.data
+            my_pos_embeddings.weight.data[max_position_upper_lim:] =\
+                self.bert.model.embeddings.position_embeddings.weight.data[-1][None,:].repeat(args.max_pos-max_position_upper_lim,1)
             self.bert.model.embeddings.position_embeddings = my_pos_embeddings
 
@@ -202,7 +204,7 @@ def __init__(self, args, device, checkpoint=None, bert_from_extractive=None):
         self.vocab_size = self.bert.model.config.vocab_size
         tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
         if (self.args.share_emb):
-            tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
+            tgt_embeddings = self.bert.model.embeddings.word_embeddings
 
         self.decoder = TransformerDecoder(
             self.args.dec_layers,
diff --git a/src/models/neural.py b/src/models/neural.py
index 5d93d3b8..ebe55fdc 100644
--- a/src/models/neural.py
+++ b/src/models/neural.py
@@ -411,7 +411,7 @@ def unshape(x):
 
         if mask is not None:
             mask = mask.unsqueeze(1).expand_as(scores)
-            scores = scores.masked_fill(mask, -1e18)
+            scores = scores.masked_fill(mask.byte(), -1e18)
 
         # 3) Apply attention dropout and compute context vectors.
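The max_position_upper_lim refactor above is what lets --max_pos exceed BERT's 512-entry position table: positions past the table reuse the last learned position vector. The same trick in isolation, assuming bert is a loaded pytorch_transformers BertModel and 600 is an illustrative target length:

    import torch.nn as nn

    max_pos, upper = 600, 512
    old = bert.embeddings.position_embeddings.weight.data      # shape [512, hidden]
    ext = nn.Embedding(max_pos, old.size(1))
    ext.weight.data[:upper] = old
    ext.weight.data[upper:] = old[-1][None, :].repeat(max_pos - upper, 1)  # repeat last learned position
    bert.embeddings.position_embeddings = ext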
diff --git a/src/models/predictor.py b/src/models/predictor.py
index dceac0f6..5b7f0897 100644
--- a/src/models/predictor.py
+++ b/src/models/predictor.py
@@ -144,8 +144,10 @@ def translate(self,
         for batch in data_iter:
             if(self.args.recall_eval):
                 gold_tgt_len = batch.tgt.size(1)
-                self.min_length = gold_tgt_len + 20
-                self.max_length = gold_tgt_len + 60
+                # self.min_length = gold_tgt_len + 20
+                # self.max_length = gold_tgt_len + 60
+                self.min_length = gold_tgt_len + 10
+                self.max_length = gold_tgt_len + 20
             batch_data = self.translate_batch(batch)
             translations = self.from_batch(batch_data)
 
@@ -154,24 +156,21 @@
             pred_str = pred.replace('[unused0]', '').replace('[unused3]', '').replace('[PAD]', '').replace('[unused1]', '').replace(r' +', ' ').replace(' [unused2] ', '<q>').replace('[unused2]', '').strip()
             gold_str = gold.strip()
             if(self.args.recall_eval):
-                _pred_str = ''
-                gap = 1e3
-                for sent in pred_str.split('<q>'):
-                    can_pred_str = _pred_str+ '<q>'+sent.strip()
-                    can_gap = math.fabs(len(_pred_str.split())-len(gold_str.split()))
-                    # if(can_gap>=gap):
-                    if(len(can_pred_str.split())>=len(gold_str.split())+10):
-                        pred_str = _pred_str
-                        break
-                    else:
-                        gap = can_gap
-                        _pred_str = can_pred_str
-
-
-
-                # pred_str = ' '.join(pred_str.split()[:len(gold_str.split())])
-                # self.raw_can_out_file.write(' '.join(pred).strip() + '\n')
-                # self.raw_gold_out_file.write(' '.join(gold).strip() + '\n')
+                # _pred_str = ''
+                # for sent in pred_str.split('<q>'):
+                #     can_pred_str = _pred_str+ '<q>'+sent.strip()
+                #     can_gap = math.fabs(len(_pred_str.split())-len(gold_str.split()))
+                #     # if(can_gap>=gap):
+                #     if(len(can_pred_str.split())>=len(gold_str.split())+10):
+                #         pred_str = _pred_str
+                #         break
+                #     else:
+                #         _pred_str = can_pred_str
+
+
+
+                pred_str = ' '.join(pred_str.split()[:len(gold_str.split())])
+
             self.can_out_file.write(pred_str + '\n')
             self.gold_out_file.write(gold_str + '\n')
             self.src_out_file.write(src.strip() + '\n')
diff --git a/src/models/trainer_ext.py b/src/models/trainer_ext.py
index cbca77c7..4d1fc584 100644
--- a/src/models/trainer_ext.py
+++ b/src/models/trainer_ext.py
@@ -232,32 +232,32 @@ def _block_tri(c, p):
         with torch.no_grad():
             for batch in test_iter:
                 src = batch.src
-                labels = batch.src_sent_labels
                 segs = batch.segs
                 clss = batch.clss
                 mask = batch.mask_src
                 mask_cls = batch.mask_cls
 
-
                 gold = []
                 pred = []
-
                 if (cal_lead):
                     selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                 elif (cal_oracle):
+                    labels = batch.src_sent_labels
                     selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                     range(batch.batch_size)]
                 else:
                     sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
 
-                    loss = self.loss(sent_scores, labels.float())
-                    loss = (loss * mask.float()).sum()
-                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
-                    stats.update(batch_stats)
+                    # compute the loss before sent_scores is masked and converted to numpy below
+                    if (hasattr(batch, 'src_sent_labels')):
+                        labels = batch.src_sent_labels
+                        loss = self.loss(sent_scores, labels.float())
+                        loss = (loss * mask.float()).sum()
+                        batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
+                        stats.update(batch_stats)
 
                     sent_scores = sent_scores + mask.float()
                     sent_scores = sent_scores.cpu().data.numpy()
                     selected_ids = np.argsort(-sent_scores, 1)
-                    # selected_ids = np.sort(selected_ids,1)
 
                 for i, idx in enumerate(selected_ids):
                     _pred = []
                     if (len(batch.src_str[i]) == 0):
diff --git a/src/prepro/data_builder.py b/src/prepro/data_builder.py
index bb7b30e3..628a005c 100644
--- a/src/prepro/data_builder.py
+++ b/src/prepro/data_builder.py
@@ -344,8 +344,6 @@ def format_to_lines(args):
             test_files.append(f)
         elif (real_name in corpus_mapping['train']):
             train_files.append(f)
-        # else:
-        #     train_files.append(f)
 
     corpora = {'train': train_files, 'valid': valid_files, 'test': test_files}
     for corpus_type in ['train', 'valid', 'test']:
diff --git a/src/preprocess.py b/src/preprocess.py
index f54d3232..e2732f91 100644
--- a/src/preprocess.py
+++ b/src/preprocess.py
@@ -42,30 +42,30 @@ def str2bool(v):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("-pretrained_model", default='bert', type=str)
+    parser.add_argument("--pretrained_model", default='bert', type=str)
 
-    parser.add_argument("-mode", default='', type=str)
-    parser.add_argument("-select_mode", default='greedy', type=str)
-    parser.add_argument("-map_path", default='../../data/')
-    parser.add_argument("-raw_path", default='../../line_data')
-    parser.add_argument("-save_path", default='../../data/')
+    parser.add_argument("--mode", default='', type=str)
+    parser.add_argument("--select_mode", default='greedy', type=str)
+    parser.add_argument("--map_path", default='../../data/')
+    parser.add_argument("--raw_path", default='../../line_data')
+    parser.add_argument("--save_path", default='../../data/')
 
-    parser.add_argument("-shard_size", default=2000, type=int)
-    parser.add_argument('-min_src_nsents', default=3, type=int)
-    parser.add_argument('-max_src_nsents', default=100, type=int)
-    parser.add_argument('-min_src_ntokens_per_sent', default=5, type=int)
-    parser.add_argument('-max_src_ntokens_per_sent', default=200, type=int)
-    parser.add_argument('-min_tgt_ntokens', default=5, type=int)
-    parser.add_argument('-max_tgt_ntokens', default=500, type=int)
+    parser.add_argument("--shard_size", default=2000, type=int)
+    parser.add_argument('--min_src_nsents', default=3, type=int)
+    parser.add_argument('--max_src_nsents', default=100, type=int)
+    parser.add_argument('--min_src_ntokens_per_sent', default=5, type=int)
+    parser.add_argument('--max_src_ntokens_per_sent', default=200, type=int)
+    parser.add_argument('--min_tgt_ntokens', default=5, type=int)
+    parser.add_argument('--max_tgt_ntokens', default=500, type=int)
 
-    parser.add_argument("-lower", type=str2bool, nargs='?',const=True,default=True)
-    parser.add_argument("-use_bert_basic_tokenizer", type=str2bool, nargs='?',const=True,default=False)
+    parser.add_argument("--lower", type=str2bool, nargs='?',const=True,default=True)
+    parser.add_argument("--use_bert_basic_tokenizer", type=str2bool, nargs='?',const=True,default=False)
 
-    parser.add_argument('-log_file', default='../../logs/cnndm.log')
+    parser.add_argument('--log_file', default='../../logs/cnndm.log')
 
-    parser.add_argument('-dataset', default='')
+    parser.add_argument('--dataset', default='')
 
-    parser.add_argument('-n_cpus', default=2, type=int)
+    parser.add_argument('--n_cpus', default=2, type=int)
 
     args = parser.parse_args()
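With this change every preprocess.py flag is spelled GNU-style with a double dash, matching the calls in the runs/*.sh scripts above, e.g. the STEP 4 invocation: python src/preprocess.py --mode format_to_bert --raw_path ./json_data --save_path ./bert_data --lower --n_cpus 1 --log_file ./logs/preprocess.log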
diff --git a/src/tags b/src/tags
new file mode 100644
index 00000000..1744951d
--- /dev/null
+++ b/src/tags
@@ -0,0 +1,625 @@
[625 lines of machine-generated Exuberant Ctags index for src/ omitted]
import init_logger$/;" i +init_logger train_abstractive.py /^from others.logging import logger, init_logger$/;" i +init_logger train_extractive.py /^from others.logging import logger, init_logger$/;" i +is_master distributed.py /^def is_master(gpu_ranks, device_id):$/;" f +itertools prepro/data_builder.py /^import itertools$/;" i +json prepro/data_builder.py /^import json$/;" i +length_average translate/penalties.py /^ def length_average(self, beam, logprobs, alpha=0.):$/;" m class:PenaltyBuilder +length_none translate/penalties.py /^ def length_none(self, beam, logprobs, alpha=0., beta=0.):$/;" m class:PenaltyBuilder +length_penalty translate/penalties.py /^ def length_penalty(self):$/;" m class:PenaltyBuilder +length_wu translate/penalties.py /^ def length_wu(self, beam, logprobs, alpha=0.):$/;" m class:PenaltyBuilder +load_dataset models/data_loader.py /^def load_dataset(args, corpus_type, shuffle):$/;" f +load_dataset train_abstractive.py /^from models.data_loader import load_dataset$/;" i +load_dataset train_extractive.py /^from models.data_loader import load_dataset$/;" i +load_json prepro/data_builder.py /^def load_json(p, lower):$/;" f +load_state_dict models/optimizers.py /^ def load_state_dict(self, state_dicts):$/;" m class:MultipleOptimizer +load_text models/data_loader.py /^def load_text(args, source_fp, target_fp, device):$/;" f +load_vocab others/tokenization.py /^def load_vocab(vocab_file):$/;" f +load_xml prepro/data_builder.py /^def load_xml(p):$/;" f +log models/predictor.py /^ def log(self, sent_number):$/;" m class:Translation +log models/reporter.py /^ def log(self, *args, **kwargs):$/;" m class:ReportMgrBase +log models/reporter_ext.py /^ def log(self, *args, **kwargs):$/;" m class:ReportMgrBase +log others/pyrouge.py /^from pyrouge.utils import log$/;" i +log_tensorboard models/reporter.py /^ def log_tensorboard(self, prefix, writer, learning_rate, step):$/;" m class:Statistics +log_tensorboard models/reporter_ext.py /^ def log_tensorboard(self, prefix, writer, learning_rate, step):$/;" m class:Statistics +logger distributed.py /^from others.logging import logger$/;" i +logger models/data_loader.py /^from others.logging import logger$/;" i +logger models/reporter.py /^from others.logging import logger$/;" i +logger models/reporter_ext.py /^from others.logging import logger$/;" i +logger models/trainer.py /^from others.logging import logger$/;" i +logger models/trainer_ext.py /^from others.logging import logger$/;" i +logger others/logging.py /^logger = logging.getLogger()$/;" v +logger others/tokenization.py /^logger = logging.getLogger(__name__)$/;" v +logger prepro/data_builder.py /^from others.logging import logger$/;" i +logger train_abstractive.py /^from others.logging import logger, init_logger$/;" i +logger train_extractive.py /^from others.logging import logger, init_logger$/;" i +logging others/logging.py /^import logging$/;" i +logging others/tokenization.py /^import logging$/;" i +map_batch_fn models/decoder.py /^ def map_batch_fn(self, fn):$/;" m class:TransformerDecoderState +map_batch_fn models/neural.py /^ def map_batch_fn(self, fn):$/;" m class:DecoderState +math distributed.py /^import math$/;" i +math models/adam.py /^import math$/;" i +math models/encoder.py /^import math$/;" i +math models/neural.py /^import math$/;" i +math models/predictor.py /^import math$/;" i +math models/reporter.py /^import math$/;" i +maybe_log_tensorboard models/reporter.py /^ def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):$/;" m class:ReportMgr 
+maybe_log_tensorboard models/reporter_ext.py /^ def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):$/;" m class:ReportMgr +mkdtemp others/pyrouge.py /^from tempfile import mkdtemp$/;" i +model_builder train_abstractive.py /^from models import data_loader, model_builder$/;" i +model_builder train_extractive.py /^from models import data_loader, model_builder$/;" i +model_filename_pattern others/pyrouge.py /^ def model_filename_pattern(self):$/;" m class:Rouge155 +model_filename_pattern others/pyrouge.py /^ def model_filename_pattern(self, pattern):$/;" m class:Rouge155 +model_flags train.py /^model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',$/;" v +model_flags train_abstractive.py /^model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',$/;" v +model_flags train_extractive.py /^model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval', 'rnn_size']$/;" v +monolithic_compute_loss models/loss.py /^ def monolithic_compute_loss(self, batch, output):$/;" m class:LossComputeBase +multi_init distributed.py /^def multi_init(device_id, world_size,gpu_ranks):$/;" f +n_grams post_stats.py /^def n_grams(tokens, n):$/;" f +nn models/decoder.py /^import torch.nn as nn$/;" i +nn models/encoder.py /^import torch.nn as nn$/;" i +nn models/loss.py /^import torch.nn as nn$/;" i +nn models/loss.py /^import torch.nn.functional as F$/;" i +nn models/model_builder.py /^import torch.nn as nn$/;" i +nn models/neural.py /^import torch.nn as nn$/;" i +nn models/neural.py /^import torch.nn.functional as F$/;" i +np models/decoder.py /^import numpy as np$/;" i +np models/trainer.py /^import numpy as np$/;" i +np models/trainer_ext.py /^import numpy as np$/;" i +nyt_remove_words prepro/data_builder.py /^nyt_remove_words = ["photo", "graph", "chart", "map", "table", "drawing"]$/;" v +open others/tokenization.py /^from io import open$/;" i +optim models/optimizers.py /^import torch.optim as optim$/;" i +os cal_rouge.py /^import os$/;" i +os models/predictor.py /^import os$/;" i +os models/trainer.py /^import os$/;" i +os models/trainer_ext.py /^import os$/;" i +os others/pyrouge.py /^import os$/;" i +os others/tokenization.py /^import os$/;" i +os others/utils.py /^import os$/;" i +os prepro/data_builder.py /^import os$/;" i +os train.py /^import os$/;" i +os train_abstractive.py /^import os$/;" i +os train_extractive.py /^import os$/;" i +output models/reporter.py /^ def output(self, step, num_steps, learning_rate, start):$/;" m class:Statistics +output models/reporter_ext.py /^ def output(self, step, num_steps, learning_rate, start):$/;" m class:Statistics +output_to_dict others/pyrouge.py /^ def output_to_dict(self, output):$/;" m class:Rouge155 +parser others/pyrouge.py /^ parser = argparse.ArgumentParser(parents=[rouge_path_parser])$/;" v class:Rouge155 +parser post_stats.py /^ parser = argparse.ArgumentParser()$/;" v +partial others/pyrouge.py /^from functools import partial$/;" i +path post_stats.py /^from os import path$/;" i +penalties translate/beam.py /^from translate import penalties$/;" i +pickle distributed.py /^import pickle$/;" i +pjoin prepro/data_builder.py /^from os.path import join as pjoin$/;" i +platform others/pyrouge.py /^import platform$/;" i +ppl models/reporter.py /^ def ppl(self):$/;" m class:Statistics +preprocess models/data_loader.py /^ def preprocess(self, ex, is_test):$/;" m class:DataIterator +preprocess 
prepro/data_builder.py /^ def preprocess(self, src, tgt, sent_labels, use_bert_basic_tokenizer=False, is_test=False):$/;" m class:BertData +print_function distributed.py /^from __future__ import print_function$/;" i +print_function models/predictor.py /^from __future__ import print_function$/;" i +print_function models/reporter.py /^from __future__ import print_function$/;" i +print_function models/reporter_ext.py /^from __future__ import print_function$/;" i +print_function others/pyrouge.py /^from __future__ import print_function, unicode_literals, division$/;" i +print_function others/tokenization.py /^from __future__ import absolute_import, division, print_function, unicode_literals$/;" i +process cal_rouge.py /^def process(data):$/;" f +process others/pyrouge.py /^ def process(input_dir, output_dir, function):$/;" m class:DirectoryProcessor +process others/utils.py /^def process(params):$/;" f +pyrouge cal_rouge.py /^from others import pyrouge$/;" i +pyrouge others/utils.py /^from others import pyrouge$/;" i +random models/data_loader.py /^import random$/;" i +random prepro/data_builder.py /^import random$/;" i +random train_abstractive.py /^import random$/;" i +random train_extractive.py /^import random$/;" i +re others/pyrouge.py /^import re$/;" i +re others/utils.py /^import re$/;" i +re post_stats.py /^import re$/;" i +re prepro/data_builder.py /^import re$/;" i +recover_from_corenlp prepro/data_builder.py /^def recover_from_corenlp(s):$/;" f +reduce post_stats.py /^from functools import reduce$/;" i +repeat_beam_size_times models/decoder.py /^ def repeat_beam_size_times(self, beam_size):$/;" m class:TransformerDecoderState +report_step models/reporter.py /^ def report_step(self, lr, step, train_stats=None, valid_stats=None):$/;" m class:ReportMgrBase +report_step models/reporter_ext.py /^ def report_step(self, lr, step, train_stats=None, valid_stats=None):$/;" m class:ReportMgrBase +report_training models/reporter.py /^ def report_training(self, step, num_steps, learning_rate,$/;" m class:ReportMgrBase +report_training models/reporter_ext.py /^ def report_training(self, step, num_steps, learning_rate,$/;" m class:ReportMgrBase +rouge others/pyrouge.py /^ rouge = Rouge155(args.rouge_home)$/;" v class:Rouge155 +rouge_path_parser others/pyrouge.py /^ from utils.argparsers import rouge_path_parser$/;" i +rouge_results_to_str cal_rouge.py /^def rouge_results_to_str(results_dict):$/;" f +rouge_results_to_str models/predictor.py /^from others.utils import rouge_results_to_str, test_rouge, tile$/;" i +rouge_results_to_str models/trainer.py /^from others.utils import test_rouge, rouge_results_to_str$/;" i +rouge_results_to_str models/trainer_ext.py /^from others.utils import test_rouge, rouge_results_to_str$/;" i +rouge_results_to_str others/utils.py /^def rouge_results_to_str(results_dict):$/;" f +run train_abstractive.py /^def run(args, device_id, error_queue):$/;" f +run train_extractive.py /^def run(args, device_id, error_queue):$/;" f +save_home_dir others/pyrouge.py /^ def save_home_dir(self):$/;" m class:Rouge155 +score models/neural.py /^ def score(self, h_t, h_s):$/;" m class:GlobalAttention +score translate/beam.py /^ def score(self, beam, logprobs):$/;" m class:GNMTGlobalScorer +sequence_mask models/neural.py /^def sequence_mask(lengths, max_len=None):$/;" f +set_parameters models/optimizers.py /^ def set_parameters(self, params):$/;" m class:Optimizer +settings_file others/pyrouge.py /^ def settings_file(self):$/;" m class:Rouge155 +shape models/neural.py /^ def shape(x):$/;" 
f function:MultiHeadedAttention.forward +sharded_compute_loss models/loss.py /^ def sharded_compute_loss(self, batch, output,$/;" m class:LossComputeBase +shards models/loss.py /^def shards(state, shard_size, eval_only=False):$/;" f +shutil cal_rouge.py /^import shutil$/;" i +shutil others/utils.py /^import shutil$/;" i +signal train_abstractive.py /^ import signal$/;" i +signal train_abstractive.py /^import signal$/;" i +signal train_extractive.py /^ import signal$/;" i +signal train_extractive.py /^import signal$/;" i +signal_handler train_abstractive.py /^ def signal_handler(self, signalnum, stackframe):$/;" m class:ErrorHandler +signal_handler train_extractive.py /^ def signal_handler(self, signalnum, stackframe):$/;" m class:ErrorHandler +sort_finished translate/beam.py /^ def sort_finished(self, minimum=None):$/;" m class:Beam +split_sentences others/pyrouge.py /^ def split_sentences(self):$/;" m class:Rouge155 +start models/reporter.py /^ def start(self):$/;" m class:ReportMgrBase +start models/reporter_ext.py /^ def start(self):$/;" m class:ReportMgrBase +state models/optimizers.py /^ def state(self):$/;" m class:MultipleOptimizer +state_dict models/optimizers.py /^ def state_dict(self):$/;" m class:MultipleOptimizer +step models/adam.py /^ def step(self, closure=None):$/;" m class:Adam +step models/optimizers.py /^ def step(self):$/;" m class:MultipleOptimizer +step models/optimizers.py /^ def step(self):$/;" m class:Optimizer +str2bool post_stats.py /^def str2bool(v):$/;" f +str2bool preprocess.py /^def str2bool(v):$/;" f +str2bool train.py /^def str2bool(v):$/;" f +str2bool train_abstractive.py /^def str2bool(v):$/;" f +subprocess prepro/data_builder.py /^import subprocess$/;" i +sys cal_rouge.py /^import sys$/;" i +sys models/reporter.py /^import sys$/;" i +sys models/reporter_ext.py /^import sys$/;" i +system_filename_pattern others/pyrouge.py /^ def system_filename_pattern(self):$/;" m class:Rouge155 +system_filename_pattern others/pyrouge.py /^ def system_filename_pattern(self, pattern):$/;" m class:Rouge155 +test models/trainer.py /^ def test(self, test_iter, step, cal_lead=False, cal_oracle=False):$/;" m class:Trainer +test models/trainer_ext.py /^ def test(self, test_iter, step, cal_lead=False, cal_oracle=False):$/;" m class:Trainer +test_abs train.py /^from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs$/;" i +test_abs train_abstractive.py /^def test_abs(args, device_id, pt, step):$/;" f +test_ext train.py /^from train_extractive import train_ext, validate_ext, test_ext, test_text_ext$/;" i +test_ext train_extractive.py /^def test_ext(args, device_id, pt, step):$/;" f +test_rouge cal_rouge.py /^def test_rouge(cand, ref,num_processes):$/;" f +test_rouge models/predictor.py /^from others.utils import rouge_results_to_str, test_rouge, tile$/;" i +test_rouge models/trainer.py /^from others.utils import test_rouge, rouge_results_to_str$/;" i +test_rouge models/trainer_ext.py /^from others.utils import test_rouge, rouge_results_to_str$/;" i +test_rouge others/utils.py /^def test_rouge(temp_dir, cand, ref):$/;" f +test_text_abs train.py /^from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs$/;" i +test_text_abs train_abstractive.py /^def test_text_abs(args):$/;" f +test_text_ext train.py /^from train_extractive import train_ext, validate_ext, test_ext, test_text_ext$/;" i +test_text_ext train_extractive.py /^def test_text_ext(args):$/;" f +threading train_abstractive.py /^ import threading$/;" i +threading 
train_extractive.py /^ import threading$/;" i +tile models/predictor.py /^from others.utils import rouge_results_to_str, test_rouge, tile$/;" i +tile others/utils.py /^def tile(x, count, dim=0):$/;" f +time cal_rouge.py /^import time$/;" i +time models/reporter.py /^import time$/;" i +time models/reporter_ext.py /^import time$/;" i +time others/utils.py /^import time$/;" i +time preprocess.py /^import time$/;" i +time train_abstractive.py /^import time$/;" i +time train_extractive.py /^import time$/;" i +tokenize others/tokenization.py /^ def tokenize(self, text):$/;" m class:BasicTokenizer +tokenize others/tokenization.py /^ def tokenize(self, text):$/;" m class:WordpieceTokenizer +tokenize others/tokenization.py /^ def tokenize(self, text, use_bert_basic_tokenizer=False):$/;" m class:BertTokenizer +tokenize prepro/data_builder.py /^def tokenize(args):$/;" f +torch distributed.py /^import torch.distributed$/;" i +torch models/adam.py /^import torch$/;" i +torch models/data_loader.py /^import torch$/;" i +torch models/decoder.py /^import torch$/;" i +torch models/decoder.py /^import torch.nn as nn$/;" i +torch models/encoder.py /^import torch$/;" i +torch models/encoder.py /^import torch.nn as nn$/;" i +torch models/loss.py /^import torch$/;" i +torch models/loss.py /^import torch.nn as nn$/;" i +torch models/loss.py /^import torch.nn.functional as F$/;" i +torch models/model_builder.py /^import torch$/;" i +torch models/model_builder.py /^import torch.nn as nn$/;" i +torch models/neural.py /^import torch$/;" i +torch models/neural.py /^import torch.nn as nn$/;" i +torch models/neural.py /^import torch.nn.functional as F$/;" i +torch models/optimizers.py /^import torch$/;" i +torch models/optimizers.py /^import torch.optim as optim$/;" i +torch models/predictor.py /^import torch$/;" i +torch models/trainer.py /^import torch$/;" i +torch models/trainer_ext.py /^import torch$/;" i +torch prepro/data_builder.py /^import torch$/;" i +torch train_abstractive.py /^import torch$/;" i +torch train_extractive.py /^import torch$/;" i +torch translate/beam.py /^import torch$/;" i +torch translate/penalties.py /^import torch$/;" i +tqdm models/data_loader.py /^from tqdm import tqdm$/;" i +traceback train_abstractive.py /^ import traceback$/;" i +traceback train_extractive.py /^ import traceback$/;" i +train models/trainer.py /^ def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):$/;" m class:Trainer +train models/trainer_ext.py /^ def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):$/;" m class:Trainer +train_abs train.py /^from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs$/;" i +train_abs train_abstractive.py /^def train_abs(args, device_id):$/;" f +train_abs_multi train_abstractive.py /^def train_abs_multi(args):$/;" f +train_abs_single train_abstractive.py /^def train_abs_single(args, device_id):$/;" f +train_ext train.py /^from train_extractive import train_ext, validate_ext, test_ext, test_text_ext$/;" i +train_ext train_extractive.py /^def train_ext(args, device_id):$/;" f +train_iter_fct train_abstractive.py /^ def train_iter_fct():$/;" f function:train_abs_single +train_iter_fct train_extractive.py /^ def train_iter_fct():$/;" f function:train_single_ext +train_multi_ext train_extractive.py /^def train_multi_ext(args):$/;" f +train_single_ext train_extractive.py /^def train_single_ext(args, device_id):$/;" f +translate models/predictor.py /^ def translate(self,$/;" m class:Translator 
+translate_batch models/predictor.py /^ def translate_batch(self, batch, fast=False):$/;" m class:Translator +unicode_literals others/pyrouge.py /^from __future__ import print_function, unicode_literals, division$/;" i +unicode_literals others/tokenization.py /^from __future__ import absolute_import, division, print_function, unicode_literals$/;" i +unicodedata others/tokenization.py /^import unicodedata$/;" i +unshape models/neural.py /^ def unshape(x):$/;" f function:MultiHeadedAttention.forward +update models/reporter.py /^ def update(self, stat, update_n_src_words=False):$/;" m class:Statistics +update models/reporter_ext.py /^ def update(self, stat, update_n_src_words=False):$/;" m class:Statistics +update_state models/decoder.py /^ def update_state(self, new_input, previous_layer_inputs):$/;" m class:TransformerDecoderState +use_gpu models/optimizers.py /^def use_gpu(opt):$/;" f +validate models/trainer.py /^ def validate(self, valid_iter, step=0):$/;" m class:Trainer +validate models/trainer_ext.py /^ def validate(self, valid_iter, step=0):$/;" m class:Trainer +validate train_abstractive.py /^def validate(args, device_id, pt, step):$/;" f +validate train_extractive.py /^def validate(args, device_id, pt, step):$/;" f +validate_abs train.py /^from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs$/;" i +validate_abs train_abstractive.py /^def validate_abs(args, device_id):$/;" f +validate_ext train.py /^from train_extractive import train_ext, validate_ext, test_ext, test_text_ext$/;" i +validate_ext train_extractive.py /^def validate_ext(args, device_id):$/;" f +verify_dir others/pyrouge.py /^from pyrouge.utils.file_utils import verify_dir$/;" i +whitespace_tokenize others/tokenization.py /^def whitespace_tokenize(text):$/;" f +write_config others/pyrouge.py /^ def write_config(self, config_file_path=None, system_id=None):$/;" m class:Rouge155 +write_config_static others/pyrouge.py /^ def write_config_static(system_dir, system_filename_pattern,$/;" m class:Rouge155 +xavier_uniform_ models/model_builder.py /^from torch.nn.init import xavier_uniform_$/;" i +xent models/reporter.py /^ def xent(self):$/;" m class:Statistics +xent models/reporter_ext.py /^ def xent(self):$/;" m class:Statistics +xml prepro/data_builder.py /^import xml.etree.ElementTree as ET$/;" i +zero_grad models/optimizers.py /^ def zero_grad(self):$/;" m class:MultipleOptimizer
diff --git a/src/train.py b/src/train.py
index f32d5a0c..15f46606 100644
--- a/src/train.py
+++ b/src/train.py
@@ -8,12 +8,11 @@ import os
 from others.logging import init_logger
 from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs
-from train_extractive import train_ext, validate_ext, test_ext
+from train_extractive import train_ext, validate_ext, test_ext, test_text_ext
 
 model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 'enc_hidden_size', 'enc_ff_size',
                'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 'ff_actv', 'use_interval']
 
-
 def str2bool(v):
     if v.lower() in ('yes', 'true', 't', 'y', '1'):
         return True
@@ -23,90 +22,88 @@ def str2bool(v):
         raise argparse.ArgumentTypeError('Boolean value expected.')
 
-
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("-task", default='ext', type=str, choices=['ext', 'abs'])
-    parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'baseline'])
-    parser.add_argument("-mode", default='train', type=str, choices=['train', 'validate', 'test'])
-    parser.add_argument("-bert_data_path", default='../bert_data_new/cnndm')
-    parser.add_argument("-model_path", default='../models/')
-    parser.add_argument("-result_path", default='../results/cnndm')
-    parser.add_argument("-temp_dir", default='../temp')
-
-    parser.add_argument("-batch_size", default=140, type=int)
-    parser.add_argument("-test_batch_size", default=200, type=int)
-
-    parser.add_argument("-max_pos", default=512, type=int)
-    parser.add_argument("-use_interval", type=str2bool, nargs='?',const=True,default=True)
-    parser.add_argument("-large", type=str2bool, nargs='?',const=True,default=False)
-    parser.add_argument("-load_from_extractive", default='', type=str)
-
-    parser.add_argument("-sep_optim", type=str2bool, nargs='?',const=True,default=False)
-    parser.add_argument("-lr_bert", default=2e-3, type=float)
-    parser.add_argument("-lr_dec", default=2e-3, type=float)
-    parser.add_argument("-use_bert_emb", type=str2bool, nargs='?',const=True,default=False)
-
-    parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=False)
-    parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True)
-    parser.add_argument("-dec_dropout", default=0.2, type=float)
-    parser.add_argument("-dec_layers", default=6, type=int)
-    parser.add_argument("-dec_hidden_size", default=768, type=int)
-    parser.add_argument("-dec_heads", default=8, type=int)
-    parser.add_argument("-dec_ff_size", default=2048, type=int)
-    parser.add_argument("-enc_hidden_size", default=512, type=int)
-    parser.add_argument("-enc_ff_size", default=512, type=int)
-    parser.add_argument("-enc_dropout", default=0.2, type=float)
-    parser.add_argument("-enc_layers", default=6, type=int)
+    parser.add_argument("--task", default='ext', type=str, choices=['ext', 'abs'])
+    parser.add_argument("--encoder", default='bert', type=str, choices=['bert', 'baseline'])
+    parser.add_argument("--mode", default='train', type=str, choices=['train', 'validate', 'test', 'test_text'])
+    parser.add_argument("--bert_data_path", default='../bert_data_new/cnndm')
+    parser.add_argument("--model_path", default='../models/')
+    parser.add_argument("--result_path", default='../results/cnndm')
+    parser.add_argument("--temp_dir", default='../temp')
+    parser.add_argument("--text_src", default='')
+    parser.add_argument("--text_tgt", default='')
+
+    parser.add_argument("--batch_size", default=140, type=int)
+    parser.add_argument("--test_batch_size", default=200, type=int)
+    parser.add_argument("--max_ndocs_in_batch", default=6, type=int)
+
+    parser.add_argument("--max_pos", default=512, type=int)
+    parser.add_argument("--use_interval", type=str2bool, nargs='?', const=True, default=True)
+    parser.add_argument("--large", type=str2bool, nargs='?', const=True, default=False)
+    parser.add_argument("--load_from_extractive", default='', type=str)
+
+    parser.add_argument("--sep_optim", type=str2bool, nargs='?', const=True, default=False)
+    parser.add_argument("--lr_bert", default=2e-3, type=float)
+    parser.add_argument("--lr_dec", default=2e-3, type=float)
+    parser.add_argument("--use_bert_emb", type=str2bool, nargs='?', const=True, default=False)
+
+    parser.add_argument("--share_emb", type=str2bool, nargs='?', const=True, default=False)
+    parser.add_argument("--finetune_bert", type=str2bool, nargs='?', const=True, default=True)
+    parser.add_argument("--dec_dropout", default=0.2, type=float)
+    parser.add_argument("--dec_layers", default=6, type=int)
+    parser.add_argument("--dec_hidden_size", default=768, type=int)
+    parser.add_argument("--dec_heads", default=8, type=int)
+    parser.add_argument("--dec_ff_size", default=2048, type=int)
+    parser.add_argument("--enc_hidden_size", default=512, type=int)
+    parser.add_argument("--enc_ff_size", default=512, type=int)
+    parser.add_argument("--enc_dropout", default=0.2, type=float)
+    parser.add_argument("--enc_layers", default=6, type=int)
 
     # params for EXT
-    parser.add_argument("-ext_dropout", default=0.2, type=float)
-    parser.add_argument("-ext_layers", default=2, type=int)
-    parser.add_argument("-ext_hidden_size", default=768, type=int)
-    parser.add_argument("-ext_heads", default=8, type=int)
-    parser.add_argument("-ext_ff_size", default=2048, type=int)
-
-    parser.add_argument("-label_smoothing", default=0.1, type=float)
-    parser.add_argument("-generator_shard_size", default=32, type=int)
-    parser.add_argument("-alpha", default=0.6, type=float)
-    parser.add_argument("-beam_size", default=5, type=int)
-    parser.add_argument("-min_length", default=15, type=int)
-    parser.add_argument("-max_length", default=150, type=int)
-    parser.add_argument("-max_tgt_len", default=140, type=int)
-
-
-
-    parser.add_argument("-param_init", default=0, type=float)
-    parser.add_argument("-param_init_glorot", type=str2bool, nargs='?',const=True,default=True)
-    parser.add_argument("-optim", default='adam', type=str)
-    parser.add_argument("-lr", default=1, type=float)
-    parser.add_argument("-beta1", default= 0.9, type=float)
-    parser.add_argument("-beta2", default=0.999, type=float)
-    parser.add_argument("-warmup_steps", default=8000, type=int)
-    parser.add_argument("-warmup_steps_bert", default=8000, type=int)
-    parser.add_argument("-warmup_steps_dec", default=8000, type=int)
-    parser.add_argument("-max_grad_norm", default=0, type=float)
-
-    parser.add_argument("-save_checkpoint_steps", default=5, type=int)
-    parser.add_argument("-accum_count", default=1, type=int)
-    parser.add_argument("-report_every", default=1, type=int)
-    parser.add_argument("-train_steps", default=1000, type=int)
-    parser.add_argument("-recall_eval", type=str2bool, nargs='?',const=True,default=False)
-
-
-    parser.add_argument('-visible_gpus', default='-1', type=str)
-    parser.add_argument('-gpu_ranks', default='0', type=str)
-    parser.add_argument('-log_file', default='../logs/cnndm.log')
-    parser.add_argument('-seed', default=666, type=int)
-
-    parser.add_argument("-test_all", type=str2bool, nargs='?',const=True,default=False)
-    parser.add_argument("-test_from", default='')
-    parser.add_argument("-test_start_from", default=-1, type=int)
-
-    parser.add_argument("-train_from", default='')
-    parser.add_argument("-report_rouge", type=str2bool, nargs='?',const=True,default=True)
-    parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True)
+    parser.add_argument("--ext_dropout", default=0.2, type=float)
+    parser.add_argument("--ext_layers", default=2, type=int)
+    parser.add_argument("--ext_hidden_size", default=768, type=int)
+    parser.add_argument("--ext_heads", default=8, type=int)
+    parser.add_argument("--ext_ff_size", default=2048, type=int)
+
+    parser.add_argument("--label_smoothing", default=0.1, type=float)
+    parser.add_argument("--generator_shard_size", default=32, type=int)
+    parser.add_argument("--alpha", default=0.6, type=float)
+    parser.add_argument("--beam_size", default=5, type=int)
+    parser.add_argument("--min_length", default=15, type=int)
+    parser.add_argument("--max_length", default=150, type=int)
+    parser.add_argument("--max_tgt_len", default=140, type=int)
+
+    parser.add_argument("--param_init", default=0, type=float)
+    parser.add_argument("--param_init_glorot", type=str2bool, nargs='?', const=True, default=True)
+    parser.add_argument("--optim", default='adam', type=str)
+    parser.add_argument("--lr", default=1, type=float)
+    parser.add_argument("--beta1", default=0.9, type=float)
+    parser.add_argument("--beta2", default=0.999, type=float)
+    parser.add_argument("--warmup_steps", default=8000, type=int)
+    parser.add_argument("--warmup_steps_bert", default=8000, type=int)
+    parser.add_argument("--warmup_steps_dec", default=8000, type=int)
+    parser.add_argument("--max_grad_norm", default=0, type=float)
+
+    parser.add_argument("--save_checkpoint_steps", default=5, type=int)
+    parser.add_argument("--accum_count", default=1, type=int)
+    parser.add_argument("--report_every", default=1, type=int)
+    parser.add_argument("--train_steps", default=1000, type=int)
+    parser.add_argument("--recall_eval", type=str2bool, nargs='?', const=True, default=False)
+
+    parser.add_argument('--visible_gpus', default='-1', type=str)
+    parser.add_argument('--gpu_ranks', default='0', type=str)
+    parser.add_argument('--log_file', default='../logs/cnndm.log')
+    parser.add_argument('--seed', default=666, type=int)
+
+    parser.add_argument("--test_all", type=str2bool, nargs='?', const=True, default=False)
+    parser.add_argument("--test_from", default='')
+    parser.add_argument("--test_start_from", default=-1, type=int)
+
+    parser.add_argument("--train_from", default='')
+    parser.add_argument("--report_rouge", type=str2bool, nargs='?', const=True, default=True)
+    parser.add_argument("--block_trigram", type=str2bool, nargs='?', const=True, default=True)
 
     args = parser.parse_args()
     args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))]
@@ -134,12 +131,7 @@ def str2bool(v):
             step = 0
         test_abs(args, device_id, cp, step)
     elif (args.mode == 'test_text'):
-        cp = args.test_from
-        try:
-            step = int(cp.split('.')[-2].split('_')[-1])
-        except:
-            step = 0
-        test_text_abs(args, device_id, cp, step)
+        test_text_abs(args)
 
 elif (args.task == 'ext'):
     if (args.mode == 'train'):
@@ -154,9 +146,4 @@ def str2bool(v):
             step = 0
         test_ext(args, device_id, cp, step)
     elif (args.mode == 'test_text'):
-        cp = args.test_from
-        try:
-            step = int(cp.split('.')[-2].split('_')[-1])
-        except:
-            step = 0
-        test_text_abs(args, device_id, cp, step)
+        test_text_ext(args)
diff --git a/src/train_abstractive.py b/src/train_abstractive.py
index 545efde3..fd16ac45 100644
--- a/src/train_abstractive.py
+++ b/src/train_abstractive.py
@@ -225,33 +225,6 @@ def test_abs(args, device_id, pt, step):
     predictor.translate(test_iter, step)
 
 
-def test_text_abs(args, device_id, pt, step):
-    device = "cpu" if args.visible_gpus == '-1' else "cuda"
-    if (pt != ''):
-        test_from = pt
-    else:
-        test_from = args.test_from
-    logger.info('Loading checkpoint from %s' % test_from)
-
-    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
-    opt = vars(checkpoint['opt'])
-    for k in opt.keys():
-        if (k in model_flags):
-            setattr(args, k, opt[k])
-    print(args)
-
-    model = AbsSummarizer(args, device, checkpoint)
-    model.eval()
-
-    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
-                                       args.test_batch_size, device,
-                                       shuffle=False, is_test=True)
-    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
-    symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
-               'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
-    predictor = build_predictor(args, tokenizer, symbols, model, logger)
-    predictor.translate(test_iter, step)
-
 
 def baseline(args, cal_lead=False, cal_oracle=False):
     test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
@@ -332,3 +305,29 @@ def train_iter_fct():
 
     trainer = build_trainer(args, device_id, model, optim, train_loss)
     trainer.train(train_iter_fct, args.train_steps)
+
+
+
+
+def test_text_abs(args):
+
+    logger.info('Loading checkpoint from %s' % args.test_from)
+    device = "cpu" if args.visible_gpus == '-1' else "cuda"
+
+    checkpoint = torch.load(args.test_from, map_location=lambda storage, loc: storage)
+    opt = vars(checkpoint['opt'])
+    for k in opt.keys():
+        if (k in model_flags):
+            setattr(args, k, opt[k])
+    print(args)
+
+    model = AbsSummarizer(args, device, checkpoint)
+    model.eval()
+
+    test_iter = data_loader.load_text(args, args.text_src, args.text_tgt, device)
+
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)
+    symbols = {'BOS': tokenizer.vocab['[unused0]'], 'EOS': tokenizer.vocab['[unused1]'],
+               'PAD': tokenizer.vocab['[PAD]'], 'EOQ': tokenizer.vocab['[unused2]']}
+    predictor = build_predictor(args, tokenizer, symbols, model, logger)
+    predictor.translate(test_iter, -1)
diff --git a/src/train_extractive.py b/src/train_extractive.py
index 26cb6668..e6e9a4be 100644
--- a/src/train_extractive.py
+++ b/src/train_extractive.py
@@ -243,3 +243,23 @@ def train_iter_fct():
 
     trainer = build_trainer(args, device_id, model, optim)
     trainer.train(train_iter_fct, args.train_steps)
+
+
+def test_text_ext(args):
+    logger.info('Loading checkpoint from %s' % args.test_from)
+    checkpoint = torch.load(args.test_from, map_location=lambda storage, loc: storage)
+    opt = vars(checkpoint['opt'])
+    for k in opt.keys():
+        if (k in model_flags):
+            setattr(args, k, opt[k])
+    print(args)
+    device = "cpu" if args.visible_gpus == '-1' else "cuda"
+    device_id = 0 if device == "cuda" else -1
+
+    model = ExtSummarizer(args, device, checkpoint)
+    model.eval()
+
+    test_iter = data_loader.load_text(args, args.text_src, args.text_tgt, device)
+
+    trainer = build_trainer(args, device_id, model, None)
+    trainer.test(test_iter, -1)
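diff --git a/runs/test-text.sh b/runs/test-text.sh
new file mode 100644
--- /dev/null
+++ b/runs/test-text.sh
@@ -0,0 +1,28 @@
+# Example invocations for the new test_text mode wired up in src/train.py
+# above. This script is an illustrative sketch, not part of the change itself:
+# the checkpoint paths (model_step_148000.pt, ext_model.pt) and the result/log
+# paths are hypothetical placeholders. --text_src expects one source document
+# per line (see raw_data/temp.raw_src); pass --text_tgt as well if you have
+# reference summaries for the raw text.
+
+# Abstractive: test_text_abs() loads the checkpoint, builds an iterator over
+# the raw text with data_loader.load_text, and runs the beam-search predictor.
+python src/train.py \
+  --task abs \
+  --mode test_text \
+  --text_src raw_data/temp.raw_src \
+  --test_from models/model_step_148000.pt \
+  --result_path results/abs_text \
+  --visible_gpus -1 \
+  --log_file logs/test_text_abs.log
+
+# Extractive: test_text_ext() builds a trainer with no optimizer and scores
+# the sentences of the same raw text via trainer.test().
+python src/train.py \
+  --task ext \
+  --mode test_text \
+  --text_src raw_data/temp.raw_src \
+  --test_from models/ext_model.pt \
+  --result_path results/ext_text \
+  --visible_gpus -1 \
+  --log_file logs/test_text_ext.log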