diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md b/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md new file mode 100644 index 00000000..e9dff202 --- /dev/null +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md @@ -0,0 +1 @@ +基于预训练BERT的古诗生成器 diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb new file mode 100644 index 00000000..c103b616 --- /dev/null +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb @@ -0,0 +1,1018 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 用BERT实现自动写诗\n", + "\n", + "**作者**:[fiyen](https://github.com/fiyen)\n", + "\n", + "**日期**:2021.06\n", + "\n", + "**摘要**:本示例教程将会演示如何使用飞桨2.0以及PaddleNLP快速实现用BERT预训练模型生成高质量诗歌。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 摘要\n", + "古诗,中华民族最高贵的文化瑰宝,在几千年文化传承中扮演着重要的角色。诗歌已经融入中华儿女的血脉之中,上到古稀之人,下到刚入学的孩童,都能随口吟诵一首诗出来。诗句的运用体现了古今诗人对文字运用的娴熟技艺,同时寄托着诗人深远的情思。诗句或优美或刚劲,或温婉或苍凉,让人在阅读诗歌的时候,如沐春风,身临其境。\n", + "\n", + "美好的诗歌让人心向往之,当我们的眼球接受了美好景物时,谁不曾有“此情此景,我想吟诗一首”的冲动,却限于实力张口息声,半晌想不出一个合适的表达。此时,如果我们有一个强大的诗歌生成工具,岂不美哉?\n", + "\n", + "没问题,通过飞桨,搭建一个古诗自动生成模型将不再是一个困难的事情。在这里,我们将展示如何用飞桨快速搭建一个强大的古诗生成模型。\n", + "\n", + "在这个示例中,我们将快速构建基于BERT预训练模型的古诗生成器,支持诗歌风格定制,以及生成藏头诗。模型基于飞桨2.0框架,BERT预训练模型则调用自PaddleNLP,诗歌数据集采用Github开源数据集。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 1. 相关内容介绍\n", + "\n", + "### 1.1 PaddleNLP\n", + "\n", + "官网链接:[https://github.com/PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)\n", + "\n", + "![](https://github.com/fiyen/models/raw/release/2.0-beta/PaddleNLP/docs/imgs/paddlenlp.png)\n", + "\n", + "\n", + "PaddleNLP旨在帮助开发者提高文本建模的效率,通过丰富的模型库、简洁易用的API,提供飞桨2.0的最佳实践并加速NLP领域应用产业落地效率。其产品特性如下:\n", + "\n", + "* **丰富的模型库**\n", + "\n", + "涵盖了NLP主流应用相关的前沿模型,包括中文词向量、预训练模型、词法分析、文本分类、文本匹配、文本生成、机器翻译、通用对话、问答系统等。\n", + "\n", + "* **简洁易用的API**\n", + "\n", + "深度兼容飞桨2.0的高层API体系,提供更多可复用的文本建模模块,可大幅度减少数据处理、组网、训练环节的代码开发,提高开发效率。\n", + "\n", + "* **高性能分布式训练**\n", + "\n", + "通过高度优化的Transformer网络实现,结合混合精度与Fleet分布式训练API,可充分利用GPU集群资源,高效完成预训练模型的分布式训练。\n", + "\n", + "### 1.2 BERT\n", + "\n", + "BERT的全称为Bidirectional Encoder Representations from Transformers,即基于Transformers的双向编码表示模型。BERT是Transformers应用的一次巨大的成功。在该模型提出时,其在NLP领域的11个方向上都大幅刷新了SOTA。其模型的主要特点可以归纳如下:\n", + "\n", + "1. 基于Transformer。Transformer的提出将注意力机制的应用发挥到了极致,同时也解决了基于RNN的注意力机制的无法并行计算的问题,使超大规模的模型训练在时间上变得可以接受;\n", + "\n", + "2. 双向编码。其实双向编码不是BERT首创,但是基于Transformer与双向编码结合使这一做法的效用得到了最充分的发挥;\n", + "\n", + "3. 使用MLM(Mask Language Model)和NSP(Next Sentence Prediction)实现多任务训练的目标。\n", + "\n", + "4. 迁移学习。BERT模型展现出了大规模数据训练带来的有效性,而更重要的一点是,BERT实质上是一种更好的语义表征,相较于经典的Word2Vec,Glove等模型具有更好词嵌入特征。在实际应用中,我们可以直接调用训练好的BERT模型作为特征表示,进而设计下游任务。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 2. 数据设置\n", + "在这一部分,我们将介绍使用的数据集,并展示数据集的调用方法。\n", + "\n", + "### 2.1 数据准备\n", + "诗歌数据集采用Github上开源的[中华古诗词数据库](https://github.com/chinese-poetry/chinese-poetry)。\n", + "\n", + "该数据集包含了唐宋两朝近一万四千古诗人, 接近5.5万首唐诗加26万宋诗. 两宋时期1564位词人,21050首词。其中,唐宋诗歌内容在json文件夹下,这里只使用json文件夹下的数据即可。以下式单个数据的示例:\n", + "```\n", + "{\n", + " \"author\":string\"胡宿\"\n", + " \"paragraphs\":[\n", + " \"五粒青松護翠苔,石門岑寂斷纖埃。\"\n", + " \"水浮花片知仙路,風遞鸞聲認嘯臺。\"\n", + " \"桐井曉寒千乳斂,茗園春嫩一旗開。\"\n", + " \"馳煙未勒山亭字,可是英靈許再來。\"\n", + " ]\n", + " \"title\":string\"沖虛觀\"\n", + " \"id\":string\"dad91d22-4b8a-4c04-a0d5-8f7ca8aff4de\"\n", + "}\n", + ",…]\n", + "```\n", + "\n", + "可见,此数据集中多数诗歌内容为繁体字。不过不用担心,飞桨已经内置了该数据集并且已经进行了简体化,我们可以通过简单的几行代码迅速调用该数据集!如下所示:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 更新paddlenlp版本\r\n", + "!pip install --upgrade paddlenlp" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test_dataset 的样本数量:364\n", + "dev_dataset 的样本数量:995\n", + "train_dataset 的样本数量:294598\n" + ] + } + ], + "source": [ + "import paddlenlp\r\n", + "test_dataset, dev_dataset, train_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('test','dev','train'), lazy=False)\r\n", + "print('test_dataset 的样本数量:%d'%len(test_dataset))\r\n", + "print('dev_dataset 的样本数量:%d'%len(dev_dataset))\r\n", + "print('train_dataset 的样本数量:%d'%len(train_dataset))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "以上三个数据,train_dataset为训练集,test_dataset为测试集,dev_dataset为开发集。其中开发集用于训练过程的测试,以用来选择最合适的模型参数,避免模型过拟合。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 2.2 数据处理\n", + "如下为以上数据单样本的实例:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "单样本示例:{'tokens': '西\\x02风\\x02簇\\x02浪\\x02花\\x02,\\x02太\\x02湖\\x02连\\x02底\\x02冻\\x02。', 'labels': '冷\\x02照\\x02玉\\x02奁\\x02清\\x02,\\x02一\\x02片\\x02无\\x02瑕\\x02缝\\x02。\\x02面\\x02目\\x02分\\x02明\\x02,\\x02眼\\x02睛\\x02定\\x02动\\x02。\\x02不\\x02墯\\x02虚\\x02凝\\x02裂\\x02万\\x02差\\x02,\\x02漆\\x02桶\\x02漆\\x02桶\\x02。'}\n" + ] + } + ], + "source": [ + "print('单样本示例:%s'%test_dataset[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "从单个样本的实例中可以看到,每个样本都有两句。为了方便处理,这里我们直接将两句合成一句进行训练。训练中我们将用每个诗句当前的字去预测下一个字,假设我们有样本sample, 那么我们的输入为sample\\[:-1\\],要预测的目标为sample\\[1:\\]。诗句中每个字后边都有符号'\\x02',由于对当前的训练并没有帮助,所以我们将其替换掉。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\r\n", + "def data_preprocess(dataset):\r\n", + " for i, data in enumerate(dataset):\r\n", + " dataset.data[i] = ''.join(list(dataset[i].values()))\r\n", + " dataset.data[i] = re.sub('\\x02', '', dataset[i])\r\n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "处理后的单样本示例:西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。\n" + ] + } + ], + "source": [ + "# 开始处理\r\n", + "test_dataset = data_preprocess(test_dataset)\r\n", + "dev_dataset = data_preprocess(dev_dataset)\r\n", + "train_dataset = data_preprocess(train_dataset)\r\n", + "print('处理后的单样本示例:%s'%test_dataset[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "从PaddleNLP调用基于BERT预训练模型的分词工具,对诗歌进行分词和编码。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2021-06-09 15:35:07,276] [ INFO] - Downloading bert-base-chinese-vocab.txt from https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-chinese-vocab.txt\n", + "100%|██████████| 107/107 [00:00<00:00, 23081.18it/s]\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import BertTokenizer\r\n", + "\r\n", + "bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“\\[CLS\\]”标记(“\\[CLS\\]”是对一些特殊任务的留空项,对于需要此项功能的并需要标记语句开始的情况,一般会再加上“\\[BOS\\]”),在结尾添加“\\[SEP\\]”标记(需要区分句子的编码中,这个标记用来将不同的句子隔开,结尾添加“\\[EOS\\]”),这些标记在BERT模型训练中扮演者特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“\\[UNK\\]”表示分词工具无法识别的符号,“\\[PAD\\]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。\n", + "[101, 6205, 7599, 5077, 3857, 5709, 8024, 1922, 3959, 6825, 2419, 1108, 511, 1107, 4212, 4373, 100, 3926, 8024, 671, 4275, 3187, 4442, 5361, 511, 7481, 4680, 1146, 3209, 8024, 4706, 4714, 2137, 1220, 511, 679, 100, 5994, 1125, 6162, 674, 2345, 8024, 4024, 3446, 4024, 3446, 511, 102]\n", + "[CLS]西风簇浪花,太湖连底冻。冷照玉[UNK]清,一片无瑕缝。面目分明,眼睛定动。不[UNK]虚凝裂万差,漆桶漆桶。[SEP]\n", + "大道分明在眼前,时人不会悮归泉。黄芽本是乾坤气,神水根基与汞连。\n", + "[101, 1920, 6887, 1146, 3209, 1762, 4706, 1184, 8024, 3198, 782, 679, 833, 100, 2495, 3787, 511, 7942, 5715, 3315, 3221, 746, 1787, 3698, 8024, 4868, 3717, 3418, 1825, 680, 3735, 6825, 511, 102]\n", + "[CLS]大道分明在眼前,时人不会[UNK]归泉。黄芽本是乾坤气,神水根基与汞连。[SEP]\n" + ] + } + ], + "source": [ + "# 处理效果展示\r\n", + "for poem in test_dataset[0:2]:\r\n", + " token_poem, _ = bert_tokenizer.encode(poem).values()\r\n", + " print(poem)\r\n", + " print(token_poem)\r\n", + " print(''.join(bert_tokenizer.convert_ids_to_tokens(token_poem)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 2.3 构造数据读取器\n", + "预处理数据后,我们基于飞桨2.0构造数据读取器,以适应后续模型的训练。\n", + "\n", + "在构造读取器之前,我们先来了解一下BERT模型的输入是什么样子的。如下图所示:\n", + "\n", + "![](https://ai-studio-static-online.cdn.bcebos.com/d35a235024fe431e8a3c3cac62064836e53e78577d8b42d4aaa806821c781c98)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "上图中可以清晰地显示出输入数据的具体样式,包括三个部分:Token Embeddings, Segment Embeddings, Position Embeddings。在这里,Embeddings理解为嵌入,即将一个元素表示成一个1 * n的向量的形式,用以表示这个元素在一个向量空间的相对位置。这是中文文本处理如今比较普遍采用的方式。在这里,Token Embeddings为词嵌入,将分词后的词元素映射成一个个1 * n的向量。除此之外,Segment Embeddings表示每个词元素属于何种角色。具体来说,当我们需要区分一个输入中不同语句时,如在对话模型中,区分输入中每一句话是哪个对象发出的,可以用Segment Embeddings。Position Embeddings为Transformer类模型的特色,由于此类自注意力机制无法区分距离的远近,引入了该嵌入来增加距离产生的偏置。通常情况下,Position为一个从句首到句尾渐增的数列,如\\[0,1,2,3,4,5,...,n-1\\]即表示一个长度为n的输入的Position。如何得到Embeddings呢?通常是构造一个N * n的矩阵,所有元素被唯一对应一个位置索引,元素数量不大于N。每一个元素的嵌入通过其对应的索引调取矩阵对应的行的n个列上的元素,即1 * n的向量。在这个项目中,由于不需要区分每一句的角色,Segment Embeddings可以设为一样的,即索引都为相同的值 (如0)。由于飞桨的BERT模型会自动处理Segment Embeddings和Position Embeddings,在构造输入的时候我们可以忽略这两项。在进行下一步计算前,所有类型的进行加和,每个词元素对应一个合成的嵌入向量。\n", + "\n", + "需注意以下类定义中包含填充内容,使输入样本对齐到一个特定的长度,以便于模型进行批处理运算。因此在得到数据读取器的实例时,需注意参数max_len,其不超过模型所支持的最大长度(PaddleNLP默认的序列最长长度为512)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import paddle\r\n", + "from paddle.io import Dataset\r\n", + "import numpy as np\r\n", + "\r\n", + "class PoemData(Dataset):\r\n", + " \"\"\"\r\n", + " 构造诗歌数据集,继承paddle.io.Dataset\r\n", + " Parameters:\r\n", + " poems (list): 诗歌数据列表,每一个元素为一首诗歌,诗歌未经编码\r\n", + " max_len: 接收诗歌的最大长度\r\n", + " \"\"\"\r\n", + " def __init__(self, poems, tokenizer, max_len=128):\r\n", + " super(PoemData, self).__init__()\r\n", + " self.poems = poems\r\n", + " self.tokenizer = tokenizer\r\n", + " self.max_len = max_len\r\n", + " \r\n", + " def __getitem__(self, idx):\r\n", + " line = self.poems[idx]\r\n", + " token_line = self.tokenizer.encode(line)\r\n", + " token, token_type = token_line['input_ids'], token_line['token_type_ids']\r\n", + " if len(token) > self.max_len + 1:\r\n", + " token = token[:self.max_len] + token[-1:]\r\n", + " token_type = token_type[:self.max_len] + token_type[-1:]\r\n", + " input_token, input_token_type = token[:-1], token_type[:-1]\r\n", + " label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " # 输入填充\r\n", + " input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " input_pad_mask = (input_token != 0).astype('float32')\r\n", + " return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask\r\n", + " \r\n", + " def __len__(self):\r\n", + " return len(self.poems)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 3. 模型设置与训练\n", + "在这一部分,我们将快速搭建基于BERT预训练模型的古诗生成器,并对模型进行训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 3.1 预训练BERT模型\n", + "古诗生成是一个文本生成的过程,在实际中模型无法获知还未生成的内容,也即BERT中的双向关系中只能捕捉到前向关系而不能捕捉到后向关系。这个限制我们可以通过添加注意力掩码(attention mask)来屏蔽掉后向的关系,使模型无法注意到还未生成的内容,从而使BERT仍能完成文本生成任务。\n", + "\n", + "进一步地,我们可以将文本生成简化为基于BERT的词分类模型(理解为词性标注),即赋予每个词一个标签,该标签即该词后的下一个词是什么。下表为一个示例:对于诗句“床前明月光,疑是地上霜。”来说,在训练的时候,输入为“床前明月光,疑是地上霜”(注意没有“。”),而预测的内容为输入的每个词对应的标签,我们把其预测标签设置为“前明月光,疑是地上霜。”在这里,我们可以理解为,文字“床”对应的标签为“前”、文字“前”对应的标签为“明”、......、文字“霜”对应的标签为“。”。因此,我们直接调用PaddleNLP的BERT词分类模型即可,需注意模型分类的类别为词表长度。\n", + "\n", + "|句子|床前明月光,疑是地上霜。|\n", + "|:--:|:--:|\n", + "|输入|床前明月光,疑是地上霜|\n", + "|预测|前明月光,疑是地上霜。|\n", + "|流程如下||\n", + "|根据内容:床|预测内容:前|\n", + "|根据内容:床前|预测内容:明|\n", + "|根据内容:床前明|预测内容:月|\n", + "|......|......|\n", + "|根据内容:床前明月光,疑是地上霜|预测内容:。|\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddlenlp.transformers import BertModel, BertForTokenClassification\r\n", + "from paddle.nn import Layer, Linear, Softmax\r\n", + "\r\n", + "class PoetryBertModel(Layer):\r\n", + " \"\"\"\r\n", + " 基于BERT预训练模型的诗歌生成模型\r\n", + " \"\"\"\r\n", + " def __init__(self, pretrained_bert_model: str, input_length: int):\r\n", + " super(PoetryBertModel, self).__init__()\r\n", + " bert_model = BertModel.from_pretrained(pretrained_bert_model)\r\n", + " self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape\r\n", + " self.bert_for_class = BertForTokenClassification(bert_model, self.vocab_size)\r\n", + " # 生成下三角矩阵,用来mask句子后边的信息\r\n", + " self.sequence_length = input_length\r\n", + " # lower_triangle_mask为input_length * input_length的下三角矩阵(包含主对角线),该掩码作为注意力掩码的一部分(在forward的\r\n", + " # 处理中为0的部分会被处理成无穷小量,以方便在计算注意力权重的时候保证被掩盖的部分权重约等于0)。而之所以写为下三角矩阵的形式,与\r\n", + " # transformer的多头注意力计算的机制有关,细节可以了解相关论文获悉。\r\n", + " self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n", + "\r\n", + " def forward(self, token, token_type, input_mask, input_length=None):\r\n", + " # 计算attention mask\r\n", + " mask_left = paddle.reshape(input_mask, input_mask.shape + [1])\r\n", + " mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])\r\n", + " # 输入句子中有效的位置\r\n", + " mask_left = paddle.cast(mask_left, 'float32')\r\n", + " mask_right = paddle.cast(mask_right, 'float32')\r\n", + " attention_mask = paddle.matmul(mask_left, mask_right)\r\n", + " # 注意力机制计算中有效的位置\r\n", + " if input_length is not None:\r\n", + " # 之所以要再计算一次,是因为用于推理预测时,可能输入的长度不为实例化时设置的长度。这里的模型在训练时假设输入的\r\n", + " # 长度是被填充成一致的——这一步不是必须的,但是处理成一致长度比较方便处理(对应地,增加了显存的用度)。\r\n", + " lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n", + " else:\r\n", + " lower_triangle_mask = self.lower_triangle_mask\r\n", + " attention_mask = attention_mask * lower_triangle_mask\r\n", + " # 无效的位置设为极小值\r\n", + " attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10\r\n", + " attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)\r\n", + "\r\n", + " output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)\r\n", + " \r\n", + " return output_logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 3.2 定义模型损失\n", + "由于真实值中有相当一部分是填充内容,我们需重写交叉熵损失,使其忽略填充内容带来的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class PoetryBertModelLossCriterion(Layer):\r\n", + " def forward(self, pred_logits, label, input_mask):\r\n", + " loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')\r\n", + " masked_loss = paddle.mean(loss * input_mask, axis=0)\r\n", + " return paddle.sum(masked_loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 3.3 模型准备\n", + "针对预训练模型的训练,需使用较小的学习率(learning_rate)进行调优。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2021-06-05 09:16:08,229] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams\n", + "[2021-06-05 09:16:17,229] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", + " return ufunc.reduce(obj, axis, dtype, out, **passkwargs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "========================================================================================================================================\n", + " Embedding-1 [[1, 128]] [1, 128, 768] 16,226,304 \n", + " Embedding-2 [[1, 128]] [1, 128, 768] 393,216 \n", + " Embedding-3 [[1, 128]] [1, 128, 768] 1,536 \n", + " LayerNorm-1 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Dropout-1 [[1, 128, 768]] [1, 128, 768] 0 \n", + " BertEmbeddings-1 [] [1, 128, 768] 0 \n", + " Linear-1 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-2 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-3 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-4 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-1 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-3 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-2 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-5 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-2 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-6 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-4 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-3 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-1 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-7 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-8 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-9 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-10 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-2 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-6 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-4 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-11 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-5 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-12 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-7 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-5 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-2 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-13 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-14 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-15 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-16 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-3 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-9 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-6 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-17 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-8 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-18 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-10 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-7 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-3 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-19 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-20 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-21 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-22 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-4 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-12 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-8 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-23 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-11 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-24 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-13 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-9 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-4 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-25 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-26 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-27 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-28 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-5 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-15 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-10 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-29 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-14 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-30 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-16 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-11 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-5 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-31 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-32 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-33 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-34 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-6 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-18 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-12 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-35 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-17 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-36 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-19 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-13 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-6 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-37 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-38 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-39 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-40 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-7 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-21 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-14 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-41 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-20 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-42 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-22 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-15 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-7 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-43 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-44 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-45 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-46 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-8 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-24 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-16 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-47 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-23 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-48 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-25 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-17 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-8 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-49 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-50 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-51 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-52 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-9 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-27 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-18 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-53 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-26 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-54 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-28 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-19 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-9 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-55 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-56 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-57 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-58 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-10 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-30 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-20 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-59 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-29 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-60 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-31 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-21 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-10 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-61 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-62 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-63 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-64 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-11 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-33 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-22 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-65 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-32 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-66 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-34 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-23 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-11 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-67 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-68 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-69 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-70 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-12 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-36 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-24 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-71 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-35 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-72 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-37 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-25 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-12 [[1, 128, 768]] [1, 128, 768] 0 \n", + " TransformerEncoder-1 [[1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Linear-73 [[1, 768]] [1, 768] 590,592 \n", + " Tanh-2 [[1, 768]] [1, 768] 0 \n", + " BertPooler-1 [[1, 128, 768]] [1, 768] 0 \n", + " BertModel-1 [[1, 128]] [[1, 128, 768], [1, 768]] 0 \n", + " Dropout-38 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-74 [[1, 128, 768]] [1, 128, 21128] 16,247,432 \n", + "BertForTokenClassification-1 [[1, 128]] [1, 128, 21128] 0 \n", + "========================================================================================================================================\n", + "Total params: 118,515,080\n", + "Trainable params: 118,515,080\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------------------------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 219.04\n", + "Params size (MB): 452.10\n", + "Estimated Total Size (MB): 671.14\n", + "----------------------------------------------------------------------------------------------------------------------------------------\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'total_params': 118515080, 'trainable_params': 118515080}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from paddle.static import InputSpec\r\n", + "from paddlenlp.metrics import Perplexity\r\n", + "from paddle.optimizer import AdamW\r\n", + "\r\n", + "net = PoetryBertModel('bert-base-chinese', 128)\r\n", + "\r\n", + "token_ids = InputSpec((-1, 128), 'int64', 'token')\r\n", + "token_type_ids = InputSpec((-1, 128), 'int64', 'token_type')\r\n", + "input_mask = InputSpec((-1, 128), 'float32', 'input_mask')\r\n", + "label = InputSpec((-1, 128), 'int64', 'label')\r\n", + "\r\n", + "inputs = [token_ids, token_type_ids, input_mask]\r\n", + "labels = [label, input_mask]\r\n", + "\r\n", + "model = paddle.Model(net, inputs, labels)\r\n", + "model.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])\r\n", + "\r\n", + "model.summary(inputs, [input.dtype for input in inputs])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 3.4 模型训练\n", + "由于调用了预训练模型,再次调优,只需很少轮的训练即可达到较好的效果。\n", + "\n", + "训练过程中,设置save_dir参数来保存训练的模型,并通过save_freq设置保存的频率。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddle.io import DataLoader\r\n", + "\r\n", + "train_loader = DataLoader(PoemData(train_dataset, bert_tokenizer, 128), batch_size=128, shuffle=True)\r\n", + "dev_loader = DataLoader(PoemData(dev_dataset, bert_tokenizer, 128), batch_size=32, shuffle=True)\r\n", + "model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1, eval_data=dev_loader, eval_freq=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 4. 古诗生成\n", + "以下,我们定义一个类来利用已经训练好的模型完成古诗生成的任务。在生成古诗的过程中,我们将已经生成的内容作为输入,编码后输入模型,得到输入中每个词对应的分类结果。然后选取最后一个词的分类结果作为根据当前内容要预测的词。下一轮中,刚刚预测的词将加入到已生成的内容中,继续进行下一个词的预测。\n", + "\n", + "在每轮预测结果的选择中,我们可以使用贪婪的方式选取最优的结果,也可以从前几个较优结果中随机选取(可以得到更多的组合),在这里,用topk进行控制。topk的设置不应太大,否则与随机生成差别不大。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\r\n", + "\r\n", + "class PoetryGen(object):\r\n", + " \"\"\"\r\n", + " 定义一个自动生成诗句的类,按照要求生成诗句\r\n", + " model: 训练得到的预测模型\r\n", + " tokenizer: 分词编码工具\r\n", + " max_length: 生成诗句的最大长度,需小于等于model所允许的最大长度\r\n", + " \"\"\"\r\n", + " def __init__(self, model, tokenizer, max_length=512):\r\n", + " self.model = model\r\n", + " self.tokenizer = tokenizer\r\n", + " self.puncs = [',', '。', '?', ';']\r\n", + " self.max_length = max_length\r\n", + "\r\n", + " def generate(self, style='', head='', topk=2):\r\n", + " \"\"\"\r\n", + " 根据要求生成诗句\r\n", + " style (str): 生成诗句的风格,写成诗句的形式,如“大漠孤烟直,长河落日圆。”\r\n", + " head (str, list): 生成诗句的开头内容。若head为str格式,则head为诗句开始内容;\r\n", + " 若head为list格式,则head中每个元素为对应位置上诗句的开始内容(即藏头诗中的头)。\r\n", + " topk (int): 从预测的topk中选取结果\r\n", + " \"\"\"\r\n", + " head_index = 0\r\n", + " style_ids = self.tokenizer.encode(style)['input_ids']\r\n", + " # 去掉结束标记\r\n", + " style_ids = style_ids[:-1]\r\n", + " head_is_list = True if isinstance(head, list) else False\r\n", + " if head_is_list:\r\n", + " poetry_ids = self.tokenizer.encode(head[head_index])['input_ids']\r\n", + " else:\r\n", + " poetry_ids = self.tokenizer.encode(head)['input_ids']\r\n", + " # 去掉开始和结束标记\r\n", + " poetry_ids = poetry_ids[1:-1]\r\n", + " break_flag = False\r\n", + " while len(style_ids) + len(poetry_ids) <= self.max_length:\r\n", + " next_word = self._gen_next_word(style_ids + poetry_ids, topk)\r\n", + " # 对于一些符号,如[UNK], [PAD], [CLS]等,其产生后对诗句无意义,直接跳过\r\n", + " if next_word in self.tokenizer.convert_tokens_to_ids(['[UNK]', '[PAD]', '[CLS]']):\r\n", + " continue\r\n", + " if head_is_list:\r\n", + " if next_word in self.tokenizer.convert_tokens_to_ids(self.puncs):\r\n", + " head_index += 1\r\n", + " if head_index < len(head):\r\n", + " new_ids = self.tokenizer.encode(head[head_index])['input_ids']\r\n", + " new_ids = [next_word] + new_ids[1:-1]\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " break_flag = True\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " if next_word == self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0]:\r\n", + " break\r\n", + " poetry_ids += new_ids\r\n", + " if break_flag:\r\n", + " break\r\n", + " return ''.join(self.tokenizer.convert_ids_to_tokens(poetry_ids))\r\n", + "\r\n", + " def _gen_next_word(self, known_ids, topk):\r\n", + " type_token = [0] * len(known_ids)\r\n", + " mask = [1] * len(known_ids)\r\n", + " sequence_length = len(known_ids)\r\n", + " known_ids = paddle.to_tensor([known_ids], dtype='int64')\r\n", + " type_token = paddle.to_tensor([type_token], dtype='int64')\r\n", + " mask = paddle.to_tensor([mask], dtype='float32')\r\n", + " logits = self.model.network.forward(known_ids, type_token, mask, sequence_length)\r\n", + " # logits中对应最后一个词的输出即为下一个词的概率\r\n", + " words_prob = logits[0, -1, :].numpy()\r\n", + " # 依概率倒序排列后,选取前topk个词\r\n", + " words_to_be_choosen = words_prob.argsort()[::-1][:topk]\r\n", + " probs_to_be_choosen = words_prob[words_to_be_choosen]\r\n", + " # 归一化\r\n", + " probs_to_be_choosen = probs_to_be_choosen / sum(probs_to_be_choosen)\r\n", + " word_choosen = np.random.choice(words_to_be_choosen, p=probs_to_be_choosen)\r\n", + " return word_choosen\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 4.1 生成古诗示例" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 载入已经训练好的模型\r\n", + "net = PoetryBertModel('bert-base-chinese', 128)\r\n", + "model = paddle.Model(net)\r\n", + "model.load('./checkpoint/final')\r\n", + "poetry_gen = PoetryGen(model, bert_tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def poetry_show(poetry):\r\n", + " pattern = r\"([,。;?])\"\r\n", + " text = re.sub(pattern, r'\\1 ', poetry)\r\n", + " for p in text.split():\r\n", + " if p:\r\n", + " print(p)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "一雨一晴天气新,\n", + "春风桃李不胜春。\n", + "山中老去无多事,\n", + "莫道山花不是真。\n", + "山色不随人意好,\n", + "花枝只与鸟情邻。\n", + "何时得见东君面,\n", + "共醉花光醉一身。\n" + ] + } + ], + "source": [ + "# 随机生成一首诗\r\n", + "poetry = poetry_gen.generate()\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "云外有时生,\n", + "云间无限好?\n", + "月明风细细,\n", + "松响竹萧悄。\n", + "谁识此时情?\n", + "相看情未了。\n" + ] + } + ], + "source": [ + "# 生成特定风格的诗\r\n", + "poetry = poetry_gen.generate(style='会当凌绝顶,一览众山小。')\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "好好学习子,\n", + "不如癡爱官。\n", + "一身无定价,\n", + "百事有馀安。\n" + ] + } + ], + "source": [ + "# 生成特定开头的诗\r\n", + "poetry = poetry_gen.generate(head='好好学习')\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "飞来峰下白莲宫,\n", + "桨去帆来一叶东。\n", + "真境自然非世外,\n", + "好山长与白云通?\n" + ] + } + ], + "source": [ + "# 生成藏头诗\r\n", + "poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])\r\n", + "poetry_show(poetry)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}