From abffa64d95a9e2e36d6a36dc5f7c4388982672b2 Mon Sep 17 00:00:00 2001
From: WinSun <35953131+fiyen@users.noreply.github.com>
Date: Tue, 24 Nov 2020 23:21:11 +0800
Subject: [PATCH 01/14] =?UTF-8?q?=E9=A3=9E=E6=A1=A82.0=E5=AE=9E=E4=BE=8B?=
 =?UTF-8?q?=E6=95=99=E7=A8=8B=E2=80=94=E2=80=94=E4=BD=BF=E7=94=A8=E9=A2=84?=
 =?UTF-8?q?=E8=AE=AD=E7=BB=83=E8=AF=8D=E5=90=91=E9=87=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../pretrained_word_embeddings.ipynb          | 1484 +++++++++++++++++
 1 file changed, 1484 insertions(+)
 create mode 100644 paddle2.0_docs/pretrained_word_embeddings.ipynb

diff --git a/paddle2.0_docs/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings.ipynb
new file mode 100644
index 00000000..8242dc19
--- /dev/null
+++ b/paddle2.0_docs/pretrained_word_embeddings.ipynb
@@ -0,0 +1,1484 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "# Using Pre-trained Word Embeddings\n",
+    "\n",
+    "Author: [Dongyang Yan](https://github.com/fiyen) (623320480@qq.com)\n",
+    "\n",
+    "Date created: 2020/11/23\n",
+    "\n",
+    "Last modified: 2020/11/24\n",
+    "\n",
+    "Description: Tutorial to classify Imdb data using pre-trained word embeddings in PaddlePaddle 2.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Abstract\n",
+    "\n",
+    "In this example we use PaddlePaddle 2.0 to train and test a classifier on the Imdb dataset (a binary movie-review sentiment-classification dataset). Imdb is loaded directly from PaddlePaddle 2.0, and the task is solved with pre-trained word embeddings ([GloVe embedding](http://nlp.stanford.edu/projects/glove/))."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Environment setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "import paddle as pd\r\n",
+    "from paddle.io import Dataset\r\n",
+    "import numpy as np\r\n",
+    "import paddle.text as pt\r\n",
+    "import random"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Loading the Imdb dataset with PaddlePaddle 2.0\n",
+    "PaddlePaddle 2.0 ships a preprocessed Imdb dataset, so we can conveniently fetch the data instances we need and skip the usual preprocessing chores. The high-quality datasets already built into PaddlePaddle 2.0 include Conll05st, Imdb, Imikolov, Movielens, UCIHousing, WMT14 and WMT16, and loading interfaces for more common datasets will follow."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "imdb_train = pt.Imdb(mode='train', cutoff=150)\r\n",
+    "imdb_test = pt.Imdb(mode='test', cutoff=150)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Loading Imdb returns content that is already encoded. Each sample represents one document and is stored as a list whose elements are integers; every integer stands for the word at the corresponding position of the document, and words map one-to-one to integer codes. The mapping can be inspected through imdb_train.word_idx. Let us check the data we just loaded:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Training samples: 25000; test samples: 25000\n",
+      "Sample labels: {0, 1}\n",
+      "Sample dictionary: [(b'the', 0), (b'and', 1), (b'a', 2), (b'of', 3), (b'to', 4), (b'is', 5), (b'in', 6), (b'it', 7), (b'i', 8), (b'this', 9)]\n",
+      "A single sample: [5146, 43, 71, 6, 1092, 14, 0, 878, 130, 151, 5146, 18, 281, 747, 0, 5146, 3, 5146, 2165, 37, 5146, 46, 5, 71, 4089, 377, 162, 46, 5, 32, 1287, 300, 35, 203, 2136, 565, 14, 2, 253, 26, 146, 61, 372, 1, 615, 5146, 5, 30, 0, 50, 3290, 6, 2148, 14, 0, 5146, 11, 17, 451, 24, 4, 127, 10, 0, 878, 130, 43, 2, 50, 5146, 751, 5146, 5, 2, 221, 3727, 6, 9, 1167, 373, 9, 5, 5146, 7, 5, 1343, 13, 2, 5146, 1, 250, 7, 98, 4270, 56, 2316, 0, 928, 11, 11, 9, 16, 5, 5146, 5146, 6, 50, 69, 27, 280, 27, 108, 1045, 0, 2633, 4177, 3180, 17, 1675, 1, 2571]\n",
+      "Minimum sample length: 10; maximum sample length: 2469\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Training samples: %d; test samples: %d\" % (len(imdb_train), len(imdb_test)))\r\n",
+    "print(f\"Sample labels: {set(imdb_train.labels)}\")\r\n",
+    "print(f\"Sample dictionary: {list(imdb_train.word_idx.items())[:10]}\")\r\n",
+    "print(f\"A single sample: {imdb_train.docs[0]}\")\r\n",
+    "print(f\"Minimum sample length: {min([len(x) for x in imdb_train.docs])}; maximum sample length: {max([len(x) for x in imdb_train.docs])}\")"
+   ]
+  },
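+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "To make the encoding tangible, we can invert the word-to-id mapping and decode a sample back into words. The `decode_doc` helper below is our own illustration, not part of the dataset API:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Hypothetical helper (our addition): invert word_idx (word -> id) into id -> word.\r\n",
+    "idx_word = {idx: word for word, idx in imdb_train.word_idx.items()}\r\n",
+    "\r\n",
+    "def decode_doc(doc):\r\n",
+    "    # vocabulary keys are bytes (e.g. b'the') except '<unk>', so decode bytes for display\r\n",
+    "    return ' '.join(w.decode() if isinstance(w, bytes) else w\r\n",
+    "                    for w in (idx_word[i] for i in doc))\r\n",
+    "\r\n",
+    "print(decode_doc(imdb_train.docs[0][:15]))"
+   ]
+  },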
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Of the arguments above, cutoff sets the frequency threshold used when building the vocabulary: words that occur less often than cutoff in the dataset are ignored. mode selects which split the returned data serves as (test: test set, train: training set). For the training set we shuffle the sample order, which improves the classifier training that follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "shuffle_index = list(range(len(imdb_train)))\r\n",
+    "random.shuffle(shuffle_index)\r\n",
+    "train_x = [imdb_train.docs[i] for i in shuffle_index]\r\n",
+    "train_y = [imdb_train.labels[i] for i in shuffle_index]\r\n",
+    "\r\n",
+    "test_x = imdb_test.docs\r\n",
+    "test_y = imdb_test.labels"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "As the length statistics show, the samples differ in length. During training, however, every sample must have the same length so that batches can be stacked into matrices for batched computation. We therefore pad or truncate all samples to a uniform length first."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "def vectorizer(input, label=None, length=2000):\r\n",
+    "    # pad every sample with zeros, then truncate so that all samples share `length`\r\n",
+    "    if label is not None:\r\n",
+    "        for x, y in zip(input, label):\r\n",
+    "            yield np.array((x + [0]*length)[:length]).astype('int64'), np.array([y]).astype('int64')\r\n",
+    "    else:\r\n",
+    "        for x in input:\r\n",
+    "            yield np.array((x + [0]*length)[:length]).astype('int64')"
+   ]
+  },
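+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "As a quick sanity check (our own addition, not in the original tutorial), we can pull a single item from the generator and confirm it has the fixed shape and dtype the model expects:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Peek at one vectorized sample; train_x/train_y come from the cells above.\r\n",
+    "sample_x, sample_y = next(vectorizer(train_x, train_y, length=2000))\r\n",
+    "print(sample_x.shape, sample_x.dtype)  # expected: (2000,) int64\r\n",
+    "print(sample_y)                        # the label, wrapped in a length-1 array"
+   ]
+  },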
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Loading the pre-trained embeddings\n",
+    "The file used below is small enough to be read into memory in one go. Large pre-trained embedding files that do not fit into memory can instead be loaded in chunks and matched in parallel. We skip that here; see [this link](https://aistudio.baidu.com/aistudio/projectdetail/496368) if you want to learn more."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Download the pre-trained embedding file. This link is slow; for a faster download see: https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n",
+    "!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n",
+    "!unzip -q glove.6B.zip\r\n",
+    "\r\n",
+    "glove_path = \"./glove.6B.100d.txt\"  # adjust to wherever glove.6B.100d.txt is located\r\n",
+    "embeddings = {}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Let us first look at a single line of the GloVe file:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "A single line of GloVe: 'the -0.038194 -0.24487 0.72812 -0.39961 0.083172 0.043953 -0.39141 0.3344 -0.57545 0.087459 0.28787 -0.06731 0.30906 -0.26384 -0.13231 -0.20757 0.33395 -0.33848 -0.31743 -0.48336 0.1464 -0.37304 0.34577 0.052041 0.44946 -0.46971 0.02628 -0.54155 -0.15518 -0.14107 -0.039722 0.28277 0.14393 0.23464 -0.31021 0.086173 0.20397 0.52624 0.17164 -0.082378 -0.71787 -0.41531 0.20335 -0.12763 0.41367 0.55187 0.57908 -0.33477 -0.36559 -0.54857 -0.062892 0.26584 0.30205 0.99775 -0.80481 -3.0243 0.01254 -0.36942 2.2167 0.72201 -0.24978 0.92136 0.034514 0.46745 1.1079 -0.19358 -0.074575 0.23353 -0.052062 -0.22044 0.057162 -0.15806 -0.30798 -0.41625 0.37972 0.15006 -0.53212 -0.2055 -1.2526 0.071624 0.70565 0.49744 -0.42063 0.26148 -1.538 -0.30223 -0.073438 -0.28312 0.37104 -0.25217 0.016215 -0.017099 -0.38984 0.87424 -0.72569 -0.51058 -0.52028 -0.1459 0.8278 0.27062\n",
+      "'\n"
+     ]
+    }
+   ],
+   "source": [
+    "# decode with utf-8\r\n",
+    "with open(glove_path, encoding='utf-8') as gf:\r\n",
+    "    line = gf.readline()\r\n",
+    "    print(\"A single line of GloVe: '%s'\" % line)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "As shown, each line begins with a word, followed by the values of that word's vector separated by spaces. Based on this format, a dictionary of all word vectors can be built as follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Total pre-trained word vectors: 400000\n",
+      "The vector of the word 'the' is: [-0.038194, -0.24487, 0.72812, -0.39961, 0.083172, 0.043953, -0.39141, 0.3344, -0.57545, 0.087459, 0.28787, -0.06731, 0.30906, -0.26384, -0.13231, -0.20757, 0.33395, -0.33848, -0.31743, -0.48336, 0.1464, -0.37304, 0.34577, 0.052041, 0.44946, -0.46971, 0.02628, -0.54155, -0.15518, -0.14107, -0.039722, 0.28277, 0.14393, 0.23464, -0.31021, 0.086173, 0.20397, 0.52624, 0.17164, -0.082378, -0.71787, -0.41531, 0.20335, -0.12763, 0.41367, 0.55187, 0.57908, -0.33477, -0.36559, -0.54857, -0.062892, 0.26584, 0.30205, 0.99775, -0.80481, -3.0243, 0.01254, -0.36942, 2.2167, 0.72201, -0.24978, 0.92136, 0.034514, 0.46745, 1.1079, -0.19358, -0.074575, 0.23353, -0.052062, -0.22044, 0.057162, -0.15806, -0.30798, -0.41625, 0.37972, 0.15006, -0.53212, -0.2055, -1.2526, 0.071624, 0.70565, 0.49744, -0.42063, 0.26148, -1.538, -0.30223, -0.073438, -0.28312, 0.37104, -0.25217, 0.016215, -0.017099, -0.38984, 0.87424, -0.72569, -0.51058, -0.52028, -0.1459, 0.8278, 0.27062]\n"
+     ]
+    }
+   ],
+   "source": [
+    "with open(glove_path, encoding='utf-8') as gf:\r\n",
+    "    for glove in gf:\r\n",
+    "        word, embedding = glove.split(maxsplit=1)\r\n",
+    "        embedding = [float(s) for s in embedding.split(' ')]\r\n",
+    "        embeddings[word] = embedding\r\n",
+    "print(\"Total pre-trained word vectors: %d\" % len(embeddings))\r\n",
+    "print(f\"The vector of the word 'the' is: {embeddings['the']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "## Matching the dataset vocabulary with the embeddings\n",
+    "Next we extract the dataset's vocabulary. Note that the word codes are ordered by word frequency: the more frequent a word, the smaller its code."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "First 5 words of the vocabulary: [b'the', b'and', b'a', b'of', b'to']\n",
+      "Last 5 words of the vocabulary: [b'troubles', b'virtual', b'warriors', b'widely', '<unk>']\n"
+     ]
+    }
+   ],
+   "source": [
+    "word_idx = imdb_train.word_idx\r\n",
+    "vocab = [w for w in word_idx.keys()]\r\n",
+    "print(f\"First 5 words of the vocabulary: {vocab[:5]}\")\r\n",
+    "print(f\"Last 5 words of the vocabulary: {vocab[-5:]}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Looking at the last 5 words of the vocabulary, we find that the final entry is `<unk>`, a symbol that represents every word outside the vocabulary. Also, a form such as b'the' is the bytes encoding of the string 'the'; remember to convert it with b'the'.decode() when using it (`<unk>` itself is not bytes-encoded, so mind the difference).\n",
+    "Next we match each word of the vocabulary with its embedding. The pre-trained embeddings may not cover every word in the vocabulary; for words without a pre-trained vector we use the zero vector."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# dimension of the word vectors; it must match the pre-trained embeddings\r\n",
+    "dim = 100\r\n",
+    "\r\n",
+    "vocab_embeddings = np.zeros((len(vocab), dim))\r\n",
+    "for ind, word in enumerate(vocab):\r\n",
+    "    if word != '<unk>':\r\n",
+    "        word = word.decode()\r\n",
+    "    embedding = embeddings.get(word, np.zeros((dim,)))\r\n",
+    "    vocab_embeddings[ind, :] = embedding"
+   ]
+  },
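+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false
+   },
+   "source": [
+    "Before moving on, it is worth checking how much of the vocabulary GloVe actually covers. Rows of vocab_embeddings that stayed all-zero belong to words without a pre-trained vector; the check below is a small sketch of ours, not part of the original tutorial:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Count vocabulary words that received a non-zero pre-trained vector.\r\n",
+    "covered = int((np.abs(vocab_embeddings).sum(axis=1) > 0).sum())\r\n",
+    "print(\"covered: %d / %d words (%.1f%%)\" % (covered, len(vocab), 100.0 * covered / len(vocab)))"
+   ]
+  },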
"metadata": { + "collapsed": false + }, + "source": [ + "## 构建基于预训练向量的Embedding\n", + "对于预训练向量的Embedding,我们一般期望它的参数不再变动,所以要设置trainable=False。如果希望在此基础上训练参数,则需要\n", + "设置trainable=True。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "pretrained_attr = pd.ParamAttr(name='embedding',\r\n", + " initializer=pd.nn.initializer.Assign(vocab_embeddings),\r\n", + " trainable=False)\r\n", + "embedding_layer = pd.nn.Embedding(num_embeddings=len(vocab),\r\n", + " embedding_dim=dim,\r\n", + " padding_idx=word_idx[''],\r\n", + " weight_attr=pretrained_attr)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 构建分类器\n", + "这里,我们构建简单的基于一维卷积的分类模型,其结构为:Embedding->Conv1D->Pool1D->Linear。在定义Linear时,由于需要知\n", + "道输入向量的维度,我们可以按照公式[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-beta/api/paddle/nn/layer/conv/Conv2d_cn.html)\n", + "来进行计算。这里给出计算的函数如下:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "===========================================================================\n", + " Embedding-1 [[1, 2000]] [1, 2000, 100] 514,700 \n", + " Conv1D-1 [[1, 2000, 100]] [1, 998, 10] 5,010 \n", + " ReLU-1 [[1, 998, 10]] [1, 998, 10] 0 \n", + " MaxPool1D-1 [[1, 998, 10]] [1, 998, 5] 0 \n", + " Flatten-1 [[1, 998, 5]] [1, 4990] 0 \n", + " Linear-1 [[1, 4990]] [1, 2] 9,982 \n", + " Softmax-1 [[1, 2]] [1, 2] 0 \n", + "===========================================================================\n", + "Total params: 529,692\n", + "Trainable params: 529,692\n", + "Non-trainable params: 0\n", + "---------------------------------------------------------------------------\n", + "Input size (MB): 0.01\n", + "Forward/backward pass size (MB): 1.75\n", + "Params size (MB): 2.02\n", + "Estimated Total Size (MB): 3.78\n", + "---------------------------------------------------------------------------\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'total_params': 529692, 'trainable_params': 529692}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def cal_output_shape(input_shape, out_channels, kernel_size, stride, padding=0, dilation=1):\r\n", + " return out_channels, int((input_shape + 2*padding - (dilation*(kernel_size - 1) + 1)) / stride) + 1\r\n", + "\r\n", + "\r\n", + "# 定义每个样本的长度\r\n", + "length = 2000\r\n", + "\r\n", + "# 定义卷积层参数\r\n", + "kernel_size = 5\r\n", + "out_channels = 10\r\n", + "stride = 2\r\n", + "padding = 0\r\n", + "\r\n", + "output_shape = cal_output_shape(length, out_channels, kernel_size, stride, padding)\r\n", + "output_shape = cal_output_shape(output_shape[1], output_shape[0], 2, 2, 0)\r\n", + "sim_model = pd.nn.Sequential(embedding_layer,\r\n", + " pd.nn.Conv1D(in_channels=dim, out_channels=out_channels, kernel_size=kernel_size,\r\n", + " stride=stride, padding=padding, data_format='NLC', bias_attr=True),\r\n", + " pd.nn.ReLU(),\r\n", + " pd.nn.MaxPool1D(kernel_size=2, stride=2),\r\n", + " pd.nn.Flatten(),\r\n", + " pd.nn.Linear(in_features=np.prod(output_shape), out_features=2, bias_attr=True),\r\n", + " pd.nn.Softmax())\r\n", + "\r\n", + "pd.summary(sim_model, input_size=(-1, length), dtypes='int64')" + ] + }, + { 
+ "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 读取数据,进行训练\n", + "我们可以利用飞桨2.0的io.Dataset模块来构建一个数据的读取器,方便地将数据进行分批训练。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "step 10/586 - loss: 0.8757 - acc: 0.4813 - 18ms/step\n", + "step 20/586 - loss: 0.8331 - acc: 0.4828 - 13ms/step\n", + "step 30/586 - loss: 0.6944 - acc: 0.5042 - 11ms/step\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", + " return (isinstance(seq, collections.Sequence) and\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 40/586 - loss: 0.7220 - acc: 0.5070 - 10ms/step\n", + "step 50/586 - loss: 0.6808 - acc: 0.4981 - 9ms/step\n", + "step 60/586 - loss: 0.7056 - acc: 0.5010 - 9ms/step\n", + "step 70/586 - loss: 0.6920 - acc: 0.5004 - 8ms/step\n", + "step 80/586 - loss: 0.6837 - acc: 0.5035 - 8ms/step\n", + "step 90/586 - loss: 0.6995 - acc: 0.4997 - 8ms/step\n", + "step 100/586 - loss: 0.6805 - acc: 0.5056 - 8ms/step\n", + "step 110/586 - loss: 0.6981 - acc: 0.5051 - 8ms/step\n", + "step 120/586 - loss: 0.7033 - acc: 0.5070 - 8ms/step\n", + "step 130/586 - loss: 0.7437 - acc: 0.5108 - 8ms/step\n", + "step 140/586 - loss: 0.6721 - acc: 0.5109 - 8ms/step\n", + "step 150/586 - loss: 0.6856 - acc: 0.5083 - 7ms/step\n", + "step 160/586 - loss: 0.6862 - acc: 0.5119 - 7ms/step\n", + "step 170/586 - loss: 0.6881 - acc: 0.5132 - 7ms/step\n", + "step 180/586 - loss: 0.6655 - acc: 0.5141 - 7ms/step\n", + "step 190/586 - loss: 0.6620 - acc: 0.5155 - 7ms/step\n", + "step 200/586 - loss: 0.6299 - acc: 0.5219 - 7ms/step\n", + "step 210/586 - loss: 0.7355 - acc: 0.5228 - 7ms/step\n", + "step 220/586 - loss: 0.6562 - acc: 0.5267 - 7ms/step\n", + "step 230/586 - loss: 0.6495 - acc: 0.5318 - 7ms/step\n", + "step 240/586 - loss: 0.6333 - acc: 0.5375 - 7ms/step\n", + "step 250/586 - loss: 0.6000 - acc: 0.5427 - 8ms/step\n", + "step 260/586 - loss: 0.5711 - acc: 0.5496 - 8ms/step\n", + "step 270/586 - loss: 0.5693 - acc: 0.5546 - 8ms/step\n", + "step 280/586 - loss: 0.6908 - acc: 0.5616 - 8ms/step\n", + "step 290/586 - loss: 0.6217 - acc: 0.5685 - 8ms/step\n", + "step 300/586 - loss: 0.5417 - acc: 0.5743 - 8ms/step\n", + "step 310/586 - loss: 0.5207 - acc: 0.5780 - 8ms/step\n", + "step 320/586 - loss: 0.5410 - acc: 0.5841 - 8ms/step\n", + "step 330/586 - loss: 0.5647 - acc: 0.5883 - 8ms/step\n", + "step 340/586 - loss: 0.4975 - acc: 0.5930 - 8ms/step\n", + "step 350/586 - loss: 0.5611 - acc: 0.5988 - 8ms/step\n", + "step 360/586 - loss: 0.5176 - acc: 0.6044 - 8ms/step\n", + "step 370/586 - loss: 0.4878 - acc: 0.6087 - 8ms/step\n", + "step 380/586 - loss: 0.5079 - acc: 0.6131 - 8ms/step\n", + "step 390/586 - loss: 0.4918 - acc: 0.6178 - 8ms/step\n", + "step 400/586 - loss: 0.4999 - acc: 0.6220 - 8ms/step\n", + "step 410/586 - loss: 0.5087 - acc: 0.6254 - 8ms/step\n", + "step 420/586 - loss: 0.4500 - acc: 0.6286 - 8ms/step\n", + "step 430/586 - loss: 0.4677 - acc: 0.6338 - 8ms/step\n", + "step 440/586 - loss: 0.4354 - acc: 0.6377 - 8ms/step\n", + "step 450/586 - loss: 0.4049 - acc: 0.6424 - 8ms/step\n", + "step 460/586 - loss: 0.4874 - acc: 0.6459 - 
8ms/step\n", + "step 470/586 - loss: 0.6287 - acc: 0.6497 - 8ms/step\n", + "step 480/586 - loss: 0.4633 - acc: 0.6535 - 8ms/step\n", + "step 490/586 - loss: 0.4972 - acc: 0.6573 - 8ms/step\n", + "step 500/586 - loss: 0.5369 - acc: 0.6603 - 8ms/step\n", + "step 510/586 - loss: 0.5170 - acc: 0.6634 - 8ms/step\n", + "step 520/586 - loss: 0.4569 - acc: 0.6665 - 8ms/step\n", + "step 530/586 - loss: 0.4837 - acc: 0.6696 - 8ms/step\n", + "step 540/586 - loss: 0.4510 - acc: 0.6726 - 8ms/step\n", + "step 550/586 - loss: 0.5162 - acc: 0.6756 - 8ms/step\n", + "step 560/586 - loss: 0.4821 - acc: 0.6781 - 8ms/step\n", + "step 570/586 - loss: 0.4589 - acc: 0.6806 - 8ms/step\n", + "step 580/586 - loss: 0.4688 - acc: 0.6830 - 8ms/step\n", + "step 586/586 - loss: 0.4162 - acc: 0.6847 - 8ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4399 - acc: 0.8313 - 3ms/step\n", + "step 20/196 - loss: 0.4896 - acc: 0.8266 - 2ms/step\n", + "step 30/196 - loss: 0.6432 - acc: 0.8187 - 2ms/step\n", + "step 40/196 - loss: 0.4953 - acc: 0.8156 - 2ms/step\n", + "step 50/196 - loss: 0.4499 - acc: 0.8081 - 2ms/step\n", + "step 60/196 - loss: 0.4401 - acc: 0.8130 - 2ms/step\n", + "step 70/196 - loss: 0.4320 - acc: 0.8121 - 2ms/step\n", + "step 80/196 - loss: 0.5158 - acc: 0.8102 - 2ms/step\n", + "step 90/196 - loss: 0.6223 - acc: 0.8115 - 2ms/step\n", + "step 100/196 - loss: 0.4908 - acc: 0.8172 - 2ms/step\n", + "step 110/196 - loss: 0.4968 - acc: 0.8173 - 2ms/step\n", + "step 120/196 - loss: 0.4446 - acc: 0.8161 - 2ms/step\n", + "step 130/196 - loss: 0.4763 - acc: 0.8159 - 2ms/step\n", + "step 140/196 - loss: 0.4702 - acc: 0.8174 - 2ms/step\n", + "step 150/196 - loss: 0.5083 - acc: 0.8163 - 2ms/step\n", + "step 160/196 - loss: 0.5015 - acc: 0.8139 - 2ms/step\n", + "step 170/196 - loss: 0.5416 - acc: 0.8116 - 2ms/step\n", + "step 180/196 - loss: 0.4286 - acc: 0.8120 - 2ms/step\n", + "step 190/196 - loss: 0.5156 - acc: 0.8123 - 2ms/step\n", + "step 196/196 - loss: 0.5552 - acc: 0.8122 - 2ms/step\n", + "Eval samples: 6250\n", + "Epoch 2/10\n", + "step 10/586 - loss: 0.4843 - acc: 0.8375 - 7ms/step\n", + "step 20/586 - loss: 0.4507 - acc: 0.8516 - 7ms/step\n", + "step 30/586 - loss: 0.5005 - acc: 0.8521 - 7ms/step\n", + "step 40/586 - loss: 0.4608 - acc: 0.8531 - 7ms/step\n", + "step 50/586 - loss: 0.4466 - acc: 0.8481 - 7ms/step\n", + "step 60/586 - loss: 0.5826 - acc: 0.8406 - 7ms/step\n", + "step 70/586 - loss: 0.4946 - acc: 0.8415 - 7ms/step\n", + "step 80/586 - loss: 0.4346 - acc: 0.8410 - 7ms/step\n", + "step 90/586 - loss: 0.4112 - acc: 0.8465 - 7ms/step\n", + "step 100/586 - loss: 0.4780 - acc: 0.8472 - 7ms/step\n", + "step 110/586 - loss: 0.4085 - acc: 0.8477 - 7ms/step\n", + "step 120/586 - loss: 0.4291 - acc: 0.8490 - 7ms/step\n", + "step 130/586 - loss: 0.4203 - acc: 0.8498 - 7ms/step\n", + "step 140/586 - loss: 0.4696 - acc: 0.8496 - 7ms/step\n", + "step 150/586 - loss: 0.4195 - acc: 0.8502 - 7ms/step\n", + "step 160/586 - loss: 0.4378 - acc: 0.8520 - 7ms/step\n", + "step 170/586 - loss: 0.4465 - acc: 0.8528 - 7ms/step\n", + "step 180/586 - loss: 0.4533 - acc: 0.8535 - 7ms/step\n", + "step 190/586 - loss: 0.4143 - acc: 0.8556 - 7ms/step\n", + "step 200/586 - loss: 0.4385 - acc: 0.8567 - 7ms/step\n", + "step 210/586 - loss: 0.4712 - acc: 0.8580 - 7ms/step\n", + "step 220/586 - loss: 0.4541 - acc: 0.8587 - 7ms/step\n", + "step 230/586 - loss: 0.5102 - acc: 0.8598 - 7ms/step\n", + "step 240/586 - loss: 0.4461 - acc: 0.8604 - 7ms/step\n", + "step 250/586 - loss: 0.4888 - acc: 0.8598 - 7ms/step\n", + "step 
260/586 - loss: 0.4808 - acc: 0.8594 - 7ms/step\n", + "step 270/586 - loss: 0.3762 - acc: 0.8600 - 7ms/step\n", + "step 280/586 - loss: 0.4755 - acc: 0.8609 - 7ms/step\n", + "step 290/586 - loss: 0.4851 - acc: 0.8610 - 7ms/step\n", + "step 300/586 - loss: 0.4570 - acc: 0.8615 - 7ms/step\n", + "step 310/586 - loss: 0.4403 - acc: 0.8611 - 7ms/step\n", + "step 320/586 - loss: 0.3967 - acc: 0.8611 - 7ms/step\n", + "step 330/586 - loss: 0.5665 - acc: 0.8614 - 7ms/step\n", + "step 340/586 - loss: 0.4581 - acc: 0.8616 - 7ms/step\n", + "step 350/586 - loss: 0.4790 - acc: 0.8614 - 7ms/step\n", + "step 360/586 - loss: 0.4301 - acc: 0.8619 - 7ms/step\n", + "step 370/586 - loss: 0.4055 - acc: 0.8617 - 7ms/step\n", + "step 380/586 - loss: 0.3873 - acc: 0.8626 - 7ms/step\n", + "step 390/586 - loss: 0.3884 - acc: 0.8635 - 7ms/step\n", + "step 400/586 - loss: 0.3815 - acc: 0.8634 - 7ms/step\n", + "step 410/586 - loss: 0.4561 - acc: 0.8633 - 7ms/step\n", + "step 420/586 - loss: 0.4677 - acc: 0.8631 - 7ms/step\n", + "step 430/586 - loss: 0.4463 - acc: 0.8624 - 7ms/step\n", + "step 440/586 - loss: 0.4642 - acc: 0.8624 - 7ms/step\n", + "step 450/586 - loss: 0.4780 - acc: 0.8626 - 7ms/step\n", + "step 460/586 - loss: 0.4521 - acc: 0.8627 - 7ms/step\n", + "step 470/586 - loss: 0.4318 - acc: 0.8628 - 7ms/step\n", + "step 480/586 - loss: 0.4390 - acc: 0.8628 - 7ms/step\n", + "step 490/586 - loss: 0.4787 - acc: 0.8629 - 7ms/step\n", + "step 500/586 - loss: 0.4620 - acc: 0.8631 - 7ms/step\n", + "step 510/586 - loss: 0.5165 - acc: 0.8631 - 7ms/step\n", + "step 520/586 - loss: 0.4316 - acc: 0.8623 - 7ms/step\n", + "step 530/586 - loss: 0.3964 - acc: 0.8627 - 7ms/step\n", + "step 540/586 - loss: 0.4333 - acc: 0.8631 - 7ms/step\n", + "step 550/586 - loss: 0.3577 - acc: 0.8629 - 7ms/step\n", + "step 560/586 - loss: 0.4475 - acc: 0.8631 - 7ms/step\n", + "step 570/586 - loss: 0.3820 - acc: 0.8634 - 7ms/step\n", + "step 580/586 - loss: 0.4899 - acc: 0.8636 - 7ms/step\n", + "step 586/586 - loss: 0.3425 - acc: 0.8641 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4062 - acc: 0.8781 - 3ms/step\n", + "step 20/196 - loss: 0.4372 - acc: 0.8781 - 3ms/step\n", + "step 30/196 - loss: 0.5886 - acc: 0.8750 - 3ms/step\n", + "step 40/196 - loss: 0.4661 - acc: 0.8648 - 3ms/step\n", + "step 50/196 - loss: 0.4340 - acc: 0.8612 - 3ms/step\n", + "step 60/196 - loss: 0.4301 - acc: 0.8604 - 3ms/step\n", + "step 70/196 - loss: 0.4055 - acc: 0.8616 - 3ms/step\n", + "step 80/196 - loss: 0.4645 - acc: 0.8590 - 3ms/step\n", + "step 90/196 - loss: 0.5809 - acc: 0.8597 - 3ms/step\n", + "step 100/196 - loss: 0.4399 - acc: 0.8606 - 3ms/step\n", + "step 110/196 - loss: 0.4577 - acc: 0.8608 - 3ms/step\n", + "step 120/196 - loss: 0.3500 - acc: 0.8581 - 3ms/step\n", + "step 130/196 - loss: 0.4330 - acc: 0.8587 - 3ms/step\n", + "step 140/196 - loss: 0.4096 - acc: 0.8603 - 3ms/step\n", + "step 150/196 - loss: 0.4189 - acc: 0.8602 - 3ms/step\n", + "step 160/196 - loss: 0.4849 - acc: 0.8588 - 3ms/step\n", + "step 170/196 - loss: 0.4570 - acc: 0.8590 - 3ms/step\n", + "step 180/196 - loss: 0.3667 - acc: 0.8601 - 3ms/step\n", + "step 190/196 - loss: 0.4623 - acc: 0.8604 - 3ms/step\n", + "step 196/196 - loss: 0.5284 - acc: 0.8619 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 3/10\n", + "step 10/586 - loss: 0.4269 - acc: 0.8875 - 7ms/step\n", + "step 20/586 - loss: 0.3295 - acc: 0.9031 - 7ms/step\n", + "step 30/586 - loss: 0.4543 - acc: 0.9062 - 7ms/step\n", + "step 40/586 - loss: 0.3627 - acc: 0.9102 - 7ms/step\n", + "step 50/586 - loss: 0.4724 
- acc: 0.9087 - 7ms/step\n", + "step 60/586 - loss: 0.4065 - acc: 0.9104 - 7ms/step\n", + "step 70/586 - loss: 0.3910 - acc: 0.9134 - 7ms/step\n", + "step 80/586 - loss: 0.4536 - acc: 0.9086 - 7ms/step\n", + "step 90/586 - loss: 0.4164 - acc: 0.9052 - 7ms/step\n", + "step 100/586 - loss: 0.5490 - acc: 0.8994 - 7ms/step\n", + "step 110/586 - loss: 0.4750 - acc: 0.8952 - 7ms/step\n", + "step 120/586 - loss: 0.3541 - acc: 0.8964 - 7ms/step\n", + "step 130/586 - loss: 0.3955 - acc: 0.8974 - 7ms/step\n", + "step 140/586 - loss: 0.4073 - acc: 0.8971 - 7ms/step\n", + "step 150/586 - loss: 0.4303 - acc: 0.8985 - 7ms/step\n", + "step 160/586 - loss: 0.4012 - acc: 0.8984 - 7ms/step\n", + "step 170/586 - loss: 0.4510 - acc: 0.8987 - 7ms/step\n", + "step 180/586 - loss: 0.4806 - acc: 0.8993 - 7ms/step\n", + "step 190/586 - loss: 0.4275 - acc: 0.8998 - 7ms/step\n", + "step 200/586 - loss: 0.4005 - acc: 0.8995 - 7ms/step\n", + "step 210/586 - loss: 0.4164 - acc: 0.8994 - 7ms/step\n", + "step 220/586 - loss: 0.4389 - acc: 0.8999 - 7ms/step\n", + "step 230/586 - loss: 0.4320 - acc: 0.9003 - 7ms/step\n", + "step 240/586 - loss: 0.4554 - acc: 0.8995 - 7ms/step\n", + "step 250/586 - loss: 0.4506 - acc: 0.8986 - 7ms/step\n", + "step 260/586 - loss: 0.3554 - acc: 0.8987 - 7ms/step\n", + "step 270/586 - loss: 0.4138 - acc: 0.8992 - 7ms/step\n", + "step 280/586 - loss: 0.3524 - acc: 0.8987 - 7ms/step\n", + "step 290/586 - loss: 0.3577 - acc: 0.8995 - 7ms/step\n", + "step 300/586 - loss: 0.3739 - acc: 0.8996 - 7ms/step\n", + "step 310/586 - loss: 0.3896 - acc: 0.8996 - 7ms/step\n", + "step 320/586 - loss: 0.3983 - acc: 0.9000 - 7ms/step\n", + "step 330/586 - loss: 0.4169 - acc: 0.9001 - 7ms/step\n", + "step 340/586 - loss: 0.4219 - acc: 0.8982 - 7ms/step\n", + "step 350/586 - loss: 0.5360 - acc: 0.8988 - 7ms/step\n", + "step 360/586 - loss: 0.3557 - acc: 0.8984 - 7ms/step\n", + "step 370/586 - loss: 0.4556 - acc: 0.8978 - 7ms/step\n", + "step 380/586 - loss: 0.3822 - acc: 0.8975 - 7ms/step\n", + "step 390/586 - loss: 0.4795 - acc: 0.8967 - 7ms/step\n", + "step 400/586 - loss: 0.4399 - acc: 0.8965 - 7ms/step\n", + "step 410/586 - loss: 0.4165 - acc: 0.8963 - 7ms/step\n", + "step 420/586 - loss: 0.4211 - acc: 0.8968 - 7ms/step\n", + "step 430/586 - loss: 0.3752 - acc: 0.8971 - 7ms/step\n", + "step 440/586 - loss: 0.4722 - acc: 0.8962 - 7ms/step\n", + "step 450/586 - loss: 0.3402 - acc: 0.8963 - 7ms/step\n", + "step 460/586 - loss: 0.4418 - acc: 0.8967 - 7ms/step\n", + "step 470/586 - loss: 0.3263 - acc: 0.8975 - 7ms/step\n", + "step 480/586 - loss: 0.3991 - acc: 0.8974 - 7ms/step\n", + "step 490/586 - loss: 0.3989 - acc: 0.8979 - 7ms/step\n", + "step 500/586 - loss: 0.4587 - acc: 0.8978 - 7ms/step\n", + "step 510/586 - loss: 0.3556 - acc: 0.8975 - 7ms/step\n", + "step 520/586 - loss: 0.4912 - acc: 0.8977 - 7ms/step\n", + "step 530/586 - loss: 0.4094 - acc: 0.8979 - 7ms/step\n", + "step 540/586 - loss: 0.3773 - acc: 0.8984 - 7ms/step\n", + "step 550/586 - loss: 0.4833 - acc: 0.8980 - 7ms/step\n", + "step 560/586 - loss: 0.3811 - acc: 0.8980 - 7ms/step\n", + "step 570/586 - loss: 0.4198 - acc: 0.8978 - 7ms/step\n", + "step 580/586 - loss: 0.3985 - acc: 0.8984 - 7ms/step\n", + "step 586/586 - loss: 0.4302 - acc: 0.8987 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4235 - acc: 0.8531 - 3ms/step\n", + "step 20/196 - loss: 0.4380 - acc: 0.8562 - 3ms/step\n", + "step 30/196 - loss: 0.5421 - acc: 0.8583 - 3ms/step\n", + "step 40/196 - loss: 0.4682 - acc: 0.8562 - 3ms/step\n", + "step 50/196 - loss: 0.4120 - 
acc: 0.8588 - 3ms/step\n", + "step 60/196 - loss: 0.3863 - acc: 0.8589 - 3ms/step\n", + "step 70/196 - loss: 0.4057 - acc: 0.8634 - 3ms/step\n", + "step 80/196 - loss: 0.4562 - acc: 0.8633 - 3ms/step\n", + "step 90/196 - loss: 0.5596 - acc: 0.8632 - 3ms/step\n", + "step 100/196 - loss: 0.4493 - acc: 0.8653 - 3ms/step\n", + "step 110/196 - loss: 0.4656 - acc: 0.8639 - 3ms/step\n", + "step 120/196 - loss: 0.3922 - acc: 0.8604 - 3ms/step\n", + "step 130/196 - loss: 0.4482 - acc: 0.8608 - 3ms/step\n", + "step 140/196 - loss: 0.3829 - acc: 0.8632 - 3ms/step\n", + "step 150/196 - loss: 0.4171 - acc: 0.8638 - 3ms/step\n", + "step 160/196 - loss: 0.4876 - acc: 0.8615 - 3ms/step\n", + "step 170/196 - loss: 0.4649 - acc: 0.8608 - 3ms/step\n", + "step 180/196 - loss: 0.3737 - acc: 0.8627 - 3ms/step\n", + "step 190/196 - loss: 0.4659 - acc: 0.8620 - 3ms/step\n", + "step 196/196 - loss: 0.4331 - acc: 0.8634 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 4/10\n", + "step 10/586 - loss: 0.4649 - acc: 0.8938 - 7ms/step\n", + "step 20/586 - loss: 0.4502 - acc: 0.8891 - 7ms/step\n", + "step 30/586 - loss: 0.3967 - acc: 0.8969 - 7ms/step\n", + "step 40/586 - loss: 0.3733 - acc: 0.9000 - 7ms/step\n", + "step 50/586 - loss: 0.4118 - acc: 0.9094 - 7ms/step\n", + "step 60/586 - loss: 0.3935 - acc: 0.9094 - 7ms/step\n", + "step 70/586 - loss: 0.3910 - acc: 0.9125 - 7ms/step\n", + "step 80/586 - loss: 0.3524 - acc: 0.9168 - 7ms/step\n", + "step 90/586 - loss: 0.3936 - acc: 0.9184 - 7ms/step\n", + "step 100/586 - loss: 0.3414 - acc: 0.9219 - 7ms/step\n", + "step 110/586 - loss: 0.3739 - acc: 0.9244 - 7ms/step\n", + "step 120/586 - loss: 0.4057 - acc: 0.9237 - 7ms/step\n", + "step 130/586 - loss: 0.3796 - acc: 0.9226 - 7ms/step\n", + "step 140/586 - loss: 0.3649 - acc: 0.9219 - 7ms/step\n", + "step 150/586 - loss: 0.3848 - acc: 0.9208 - 7ms/step\n", + "step 160/586 - loss: 0.4138 - acc: 0.9207 - 7ms/step\n", + "step 170/586 - loss: 0.3893 - acc: 0.9219 - 7ms/step\n", + "step 180/586 - loss: 0.3575 - acc: 0.9229 - 7ms/step\n", + "step 190/586 - loss: 0.3528 - acc: 0.9248 - 7ms/step\n", + "step 200/586 - loss: 0.4436 - acc: 0.9231 - 7ms/step\n", + "step 210/586 - loss: 0.3936 - acc: 0.9232 - 7ms/step\n", + "step 220/586 - loss: 0.3917 - acc: 0.9213 - 7ms/step\n", + "step 230/586 - loss: 0.3866 - acc: 0.9219 - 7ms/step\n", + "step 240/586 - loss: 0.4124 - acc: 0.9224 - 7ms/step\n", + "step 250/586 - loss: 0.4374 - acc: 0.9215 - 7ms/step\n", + "step 260/586 - loss: 0.3602 - acc: 0.9218 - 7ms/step\n", + "step 270/586 - loss: 0.3354 - acc: 0.9223 - 7ms/step\n", + "step 280/586 - loss: 0.4723 - acc: 0.9220 - 7ms/step\n", + "step 290/586 - loss: 0.3258 - acc: 0.9230 - 7ms/step\n", + "step 300/586 - loss: 0.3674 - acc: 0.9236 - 7ms/step\n", + "step 310/586 - loss: 0.3226 - acc: 0.9241 - 6ms/step\n", + "step 320/586 - loss: 0.3961 - acc: 0.9241 - 6ms/step\n", + "step 330/586 - loss: 0.4282 - acc: 0.9237 - 6ms/step\n", + "step 340/586 - loss: 0.3943 - acc: 0.9235 - 6ms/step\n", + "step 350/586 - loss: 0.4288 - acc: 0.9224 - 6ms/step\n", + "step 360/586 - loss: 0.4189 - acc: 0.9221 - 7ms/step\n", + "step 370/586 - loss: 0.4015 - acc: 0.9227 - 7ms/step\n", + "step 380/586 - loss: 0.3946 - acc: 0.9230 - 7ms/step\n", + "step 390/586 - loss: 0.3763 - acc: 0.9233 - 7ms/step\n", + "step 400/586 - loss: 0.3684 - acc: 0.9232 - 7ms/step\n", + "step 410/586 - loss: 0.3471 - acc: 0.9233 - 7ms/step\n", + "step 420/586 - loss: 0.4221 - acc: 0.9234 - 7ms/step\n", + "step 430/586 - loss: 0.4527 - acc: 0.9232 - 7ms/step\n", + "step 440/586 
- loss: 0.3835 - acc: 0.9233 - 7ms/step\n", + "step 450/586 - loss: 0.4414 - acc: 0.9233 - 7ms/step\n", + "step 460/586 - loss: 0.3542 - acc: 0.9235 - 7ms/step\n", + "step 470/586 - loss: 0.3878 - acc: 0.9236 - 7ms/step\n", + "step 480/586 - loss: 0.4531 - acc: 0.9235 - 7ms/step\n", + "step 490/586 - loss: 0.4480 - acc: 0.9234 - 7ms/step\n", + "step 500/586 - loss: 0.3302 - acc: 0.9239 - 7ms/step\n", + "step 510/586 - loss: 0.3513 - acc: 0.9238 - 7ms/step\n", + "step 520/586 - loss: 0.4588 - acc: 0.9237 - 7ms/step\n", + "step 530/586 - loss: 0.3953 - acc: 0.9238 - 7ms/step\n", + "step 540/586 - loss: 0.4340 - acc: 0.9242 - 7ms/step\n", + "step 550/586 - loss: 0.3836 - acc: 0.9243 - 7ms/step\n", + "step 560/586 - loss: 0.3799 - acc: 0.9241 - 7ms/step\n", + "step 570/586 - loss: 0.4244 - acc: 0.9240 - 7ms/step\n", + "step 580/586 - loss: 0.3150 - acc: 0.9236 - 7ms/step\n", + "step 586/586 - loss: 0.5743 - acc: 0.9230 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.3942 - acc: 0.8906 - 2ms/step\n", + "step 20/196 - loss: 0.4010 - acc: 0.8891 - 2ms/step\n", + "step 30/196 - loss: 0.5784 - acc: 0.8750 - 2ms/step\n", + "step 40/196 - loss: 0.4673 - acc: 0.8703 - 2ms/step\n", + "step 50/196 - loss: 0.4671 - acc: 0.8669 - 2ms/step\n", + "step 60/196 - loss: 0.4023 - acc: 0.8656 - 2ms/step\n", + "step 70/196 - loss: 0.4319 - acc: 0.8679 - 2ms/step\n", + "step 80/196 - loss: 0.4205 - acc: 0.8664 - 2ms/step\n", + "step 90/196 - loss: 0.5517 - acc: 0.8656 - 2ms/step\n", + "step 100/196 - loss: 0.4190 - acc: 0.8675 - 2ms/step\n", + "step 110/196 - loss: 0.4450 - acc: 0.8682 - 2ms/step\n", + "step 120/196 - loss: 0.3771 - acc: 0.8651 - 2ms/step\n", + "step 130/196 - loss: 0.4033 - acc: 0.8659 - 2ms/step\n", + "step 140/196 - loss: 0.4189 - acc: 0.8667 - 2ms/step\n", + "step 150/196 - loss: 0.4362 - acc: 0.8660 - 2ms/step\n", + "step 160/196 - loss: 0.5045 - acc: 0.8643 - 2ms/step\n", + "step 170/196 - loss: 0.3803 - acc: 0.8651 - 2ms/step\n", + "step 180/196 - loss: 0.3570 - acc: 0.8672 - 2ms/step\n", + "step 190/196 - loss: 0.4183 - acc: 0.8679 - 2ms/step\n", + "step 196/196 - loss: 0.5245 - acc: 0.8683 - 2ms/step\n", + "Eval samples: 6250\n", + "Epoch 5/10\n", + "step 10/586 - loss: 0.3663 - acc: 0.9437 - 7ms/step\n", + "step 20/586 - loss: 0.3953 - acc: 0.9531 - 7ms/step\n", + "step 30/586 - loss: 0.4353 - acc: 0.9448 - 7ms/step\n", + "step 40/586 - loss: 0.4004 - acc: 0.9445 - 7ms/step\n", + "step 50/586 - loss: 0.3962 - acc: 0.9437 - 7ms/step\n", + "step 60/586 - loss: 0.3936 - acc: 0.9453 - 7ms/step\n", + "step 70/586 - loss: 0.3608 - acc: 0.9455 - 6ms/step\n", + "step 80/586 - loss: 0.3816 - acc: 0.9441 - 6ms/step\n", + "step 90/586 - loss: 0.4682 - acc: 0.9437 - 6ms/step\n", + "step 100/586 - loss: 0.3616 - acc: 0.9428 - 6ms/step\n", + "step 110/586 - loss: 0.4110 - acc: 0.9432 - 6ms/step\n", + "step 120/586 - loss: 0.3548 - acc: 0.9437 - 6ms/step\n", + "step 130/586 - loss: 0.3788 - acc: 0.9433 - 6ms/step\n", + "step 140/586 - loss: 0.3626 - acc: 0.9433 - 6ms/step\n", + "step 150/586 - loss: 0.3856 - acc: 0.9435 - 6ms/step\n", + "step 160/586 - loss: 0.4348 - acc: 0.9437 - 6ms/step\n", + "step 170/586 - loss: 0.3337 - acc: 0.9443 - 6ms/step\n", + "step 180/586 - loss: 0.3341 - acc: 0.9439 - 6ms/step\n", + "step 190/586 - loss: 0.3483 - acc: 0.9434 - 6ms/step\n", + "step 200/586 - loss: 0.3253 - acc: 0.9431 - 6ms/step\n", + "step 210/586 - loss: 0.3671 - acc: 0.9418 - 6ms/step\n", + "step 220/586 - loss: 0.3685 - acc: 0.9415 - 6ms/step\n", + "step 230/586 - loss: 0.4182 - acc: 
0.9413 - 6ms/step\n", + "step 240/586 - loss: 0.3367 - acc: 0.9410 - 6ms/step\n", + "step 250/586 - loss: 0.4380 - acc: 0.9407 - 6ms/step\n", + "step 260/586 - loss: 0.3579 - acc: 0.9394 - 6ms/step\n", + "step 270/586 - loss: 0.3499 - acc: 0.9388 - 6ms/step\n", + "step 280/586 - loss: 0.4419 - acc: 0.9384 - 6ms/step\n", + "step 290/586 - loss: 0.4185 - acc: 0.9378 - 6ms/step\n", + "step 300/586 - loss: 0.4595 - acc: 0.9375 - 6ms/step\n", + "step 310/586 - loss: 0.3226 - acc: 0.9378 - 6ms/step\n", + "step 320/586 - loss: 0.3661 - acc: 0.9382 - 6ms/step\n", + "step 330/586 - loss: 0.3806 - acc: 0.9383 - 6ms/step\n", + "step 340/586 - loss: 0.4106 - acc: 0.9380 - 6ms/step\n", + "step 350/586 - loss: 0.4062 - acc: 0.9375 - 6ms/step\n", + "step 360/586 - loss: 0.3989 - acc: 0.9375 - 6ms/step\n", + "step 370/586 - loss: 0.3514 - acc: 0.9383 - 6ms/step\n", + "step 380/586 - loss: 0.3183 - acc: 0.9391 - 6ms/step\n", + "step 390/586 - loss: 0.3472 - acc: 0.9395 - 6ms/step\n", + "step 400/586 - loss: 0.3165 - acc: 0.9393 - 6ms/step\n", + "step 410/586 - loss: 0.3192 - acc: 0.9393 - 6ms/step\n", + "step 420/586 - loss: 0.3826 - acc: 0.9394 - 7ms/step\n", + "step 430/586 - loss: 0.3252 - acc: 0.9401 - 7ms/step\n", + "step 440/586 - loss: 0.3815 - acc: 0.9406 - 7ms/step\n", + "step 450/586 - loss: 0.3926 - acc: 0.9408 - 7ms/step\n", + "step 460/586 - loss: 0.4072 - acc: 0.9411 - 7ms/step\n", + "step 470/586 - loss: 0.4134 - acc: 0.9412 - 7ms/step\n", + "step 480/586 - loss: 0.3375 - acc: 0.9413 - 7ms/step\n", + "step 490/586 - loss: 0.3880 - acc: 0.9414 - 7ms/step\n", + "step 500/586 - loss: 0.3885 - acc: 0.9417 - 7ms/step\n", + "step 510/586 - loss: 0.3638 - acc: 0.9417 - 7ms/step\n", + "step 520/586 - loss: 0.4671 - acc: 0.9414 - 7ms/step\n", + "step 530/586 - loss: 0.3618 - acc: 0.9412 - 7ms/step\n", + "step 540/586 - loss: 0.3202 - acc: 0.9409 - 7ms/step\n", + "step 550/586 - loss: 0.3325 - acc: 0.9405 - 7ms/step\n", + "step 560/586 - loss: 0.3969 - acc: 0.9403 - 7ms/step\n", + "step 570/586 - loss: 0.3870 - acc: 0.9399 - 7ms/step\n", + "step 580/586 - loss: 0.3297 - acc: 0.9402 - 7ms/step\n", + "step 586/586 - loss: 0.3533 - acc: 0.9400 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.3991 - acc: 0.8812 - 3ms/step\n", + "step 20/196 - loss: 0.4031 - acc: 0.8875 - 2ms/step\n", + "step 30/196 - loss: 0.5758 - acc: 0.8760 - 2ms/step\n", + "step 40/196 - loss: 0.4588 - acc: 0.8695 - 3ms/step\n", + "step 50/196 - loss: 0.4694 - acc: 0.8669 - 3ms/step\n", + "step 60/196 - loss: 0.4034 - acc: 0.8661 - 3ms/step\n", + "step 70/196 - loss: 0.4236 - acc: 0.8714 - 3ms/step\n", + "step 80/196 - loss: 0.4264 - acc: 0.8703 - 3ms/step\n", + "step 90/196 - loss: 0.5121 - acc: 0.8698 - 3ms/step\n", + "step 100/196 - loss: 0.3963 - acc: 0.8709 - 3ms/step\n", + "step 110/196 - loss: 0.4396 - acc: 0.8716 - 3ms/step\n", + "step 120/196 - loss: 0.3787 - acc: 0.8680 - 3ms/step\n", + "step 130/196 - loss: 0.4081 - acc: 0.8678 - 3ms/step\n", + "step 140/196 - loss: 0.4171 - acc: 0.8676 - 3ms/step\n", + "step 150/196 - loss: 0.4276 - acc: 0.8675 - 3ms/step\n", + "step 160/196 - loss: 0.5145 - acc: 0.8660 - 3ms/step\n", + "step 170/196 - loss: 0.3994 - acc: 0.8664 - 3ms/step\n", + "step 180/196 - loss: 0.3495 - acc: 0.8686 - 3ms/step\n", + "step 190/196 - loss: 0.4370 - acc: 0.8696 - 3ms/step\n", + "step 196/196 - loss: 0.4342 - acc: 0.8706 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 6/10\n", + "step 10/586 - loss: 0.3305 - acc: 0.9656 - 7ms/step\n", + "step 20/586 - loss: 0.3285 - acc: 0.9641 - 
7ms/step\n", + "step 30/586 - loss: 0.3835 - acc: 0.9563 - 8ms/step\n", + "step 40/586 - loss: 0.4051 - acc: 0.9492 - 7ms/step\n", + "step 50/586 - loss: 0.3310 - acc: 0.9506 - 7ms/step\n", + "step 60/586 - loss: 0.3157 - acc: 0.9542 - 7ms/step\n", + "step 70/586 - loss: 0.3776 - acc: 0.9540 - 7ms/step\n", + "step 80/586 - loss: 0.4235 - acc: 0.9531 - 8ms/step\n", + "step 90/586 - loss: 0.3765 - acc: 0.9538 - 8ms/step\n", + "step 100/586 - loss: 0.4109 - acc: 0.9537 - 8ms/step\n", + "step 110/586 - loss: 0.3178 - acc: 0.9548 - 8ms/step\n", + "step 120/586 - loss: 0.3332 - acc: 0.9560 - 8ms/step\n", + "step 130/586 - loss: 0.3541 - acc: 0.9560 - 8ms/step\n", + "step 140/586 - loss: 0.4426 - acc: 0.9551 - 8ms/step\n", + "step 150/586 - loss: 0.3988 - acc: 0.9550 - 8ms/step\n", + "step 160/586 - loss: 0.3752 - acc: 0.9553 - 8ms/step\n", + "step 170/586 - loss: 0.3670 - acc: 0.9548 - 8ms/step\n", + "step 180/586 - loss: 0.3524 - acc: 0.9542 - 8ms/step\n", + "step 190/586 - loss: 0.4168 - acc: 0.9531 - 8ms/step\n", + "step 200/586 - loss: 0.4119 - acc: 0.9536 - 8ms/step\n", + "step 210/586 - loss: 0.3779 - acc: 0.9533 - 8ms/step\n", + "step 220/586 - loss: 0.4391 - acc: 0.9536 - 8ms/step\n", + "step 230/586 - loss: 0.3181 - acc: 0.9537 - 8ms/step\n", + "step 240/586 - loss: 0.3546 - acc: 0.9543 - 8ms/step\n", + "step 250/586 - loss: 0.3768 - acc: 0.9545 - 8ms/step\n", + "step 260/586 - loss: 0.3607 - acc: 0.9544 - 7ms/step\n", + "step 270/586 - loss: 0.3783 - acc: 0.9546 - 7ms/step\n", + "step 280/586 - loss: 0.3453 - acc: 0.9542 - 7ms/step\n", + "step 290/586 - loss: 0.3470 - acc: 0.9552 - 7ms/step\n", + "step 300/586 - loss: 0.3719 - acc: 0.9547 - 7ms/step\n", + "step 310/586 - loss: 0.3817 - acc: 0.9542 - 7ms/step\n", + "step 320/586 - loss: 0.3873 - acc: 0.9546 - 7ms/step\n", + "step 330/586 - loss: 0.3214 - acc: 0.9545 - 7ms/step\n", + "step 340/586 - loss: 0.3188 - acc: 0.9546 - 7ms/step\n", + "step 350/586 - loss: 0.4134 - acc: 0.9546 - 7ms/step\n", + "step 360/586 - loss: 0.3154 - acc: 0.9549 - 7ms/step\n", + "step 370/586 - loss: 0.3639 - acc: 0.9550 - 7ms/step\n", + "step 380/586 - loss: 0.3960 - acc: 0.9550 - 7ms/step\n", + "step 390/586 - loss: 0.3466 - acc: 0.9551 - 7ms/step\n", + "step 400/586 - loss: 0.3370 - acc: 0.9555 - 7ms/step\n", + "step 410/586 - loss: 0.3841 - acc: 0.9555 - 7ms/step\n", + "step 420/586 - loss: 0.3942 - acc: 0.9552 - 7ms/step\n", + "step 430/586 - loss: 0.3547 - acc: 0.9551 - 7ms/step\n", + "step 440/586 - loss: 0.3170 - acc: 0.9553 - 7ms/step\n", + "step 450/586 - loss: 0.3266 - acc: 0.9556 - 7ms/step\n", + "step 460/586 - loss: 0.3429 - acc: 0.9553 - 7ms/step\n", + "step 470/586 - loss: 0.3164 - acc: 0.9555 - 7ms/step\n", + "step 480/586 - loss: 0.3724 - acc: 0.9555 - 7ms/step\n", + "step 490/586 - loss: 0.3533 - acc: 0.9554 - 7ms/step\n", + "step 500/586 - loss: 0.4149 - acc: 0.9556 - 7ms/step\n", + "step 510/586 - loss: 0.3577 - acc: 0.9552 - 7ms/step\n", + "step 520/586 - loss: 0.3712 - acc: 0.9553 - 7ms/step\n", + "step 530/586 - loss: 0.3233 - acc: 0.9555 - 7ms/step\n", + "step 540/586 - loss: 0.3177 - acc: 0.9556 - 7ms/step\n", + "step 550/586 - loss: 0.3508 - acc: 0.9557 - 7ms/step\n", + "step 560/586 - loss: 0.3778 - acc: 0.9553 - 7ms/step\n", + "step 570/586 - loss: 0.3157 - acc: 0.9552 - 7ms/step\n", + "step 580/586 - loss: 0.3832 - acc: 0.9551 - 7ms/step\n", + "step 586/586 - loss: 0.3516 - acc: 0.9552 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.3740 - acc: 0.8875 - 3ms/step\n", + "step 20/196 - loss: 0.3935 - acc: 0.8922 - 
3ms/step\n", + "step 30/196 - loss: 0.5860 - acc: 0.8771 - 3ms/step\n", + "step 40/196 - loss: 0.4778 - acc: 0.8719 - 3ms/step\n", + "step 50/196 - loss: 0.4675 - acc: 0.8669 - 3ms/step\n", + "step 60/196 - loss: 0.3974 - acc: 0.8625 - 3ms/step\n", + "step 70/196 - loss: 0.4264 - acc: 0.8670 - 3ms/step\n", + "step 80/196 - loss: 0.4237 - acc: 0.8668 - 3ms/step\n", + "step 90/196 - loss: 0.5286 - acc: 0.8660 - 3ms/step\n", + "step 100/196 - loss: 0.3980 - acc: 0.8669 - 3ms/step\n", + "step 110/196 - loss: 0.4362 - acc: 0.8676 - 3ms/step\n", + "step 120/196 - loss: 0.3779 - acc: 0.8638 - 3ms/step\n", + "step 130/196 - loss: 0.4090 - acc: 0.8644 - 3ms/step\n", + "step 140/196 - loss: 0.4323 - acc: 0.8652 - 3ms/step\n", + "step 150/196 - loss: 0.4067 - acc: 0.8654 - 3ms/step\n", + "step 160/196 - loss: 0.5107 - acc: 0.8637 - 3ms/step\n", + "step 170/196 - loss: 0.4058 - acc: 0.8649 - 3ms/step\n", + "step 180/196 - loss: 0.3519 - acc: 0.8663 - 3ms/step\n", + "step 190/196 - loss: 0.4454 - acc: 0.8663 - 3ms/step\n", + "step 196/196 - loss: 0.4457 - acc: 0.8672 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 7/10\n", + "step 10/586 - loss: 0.4362 - acc: 0.9437 - 7ms/step\n", + "step 20/586 - loss: 0.3194 - acc: 0.9469 - 7ms/step\n", + "step 30/586 - loss: 0.4111 - acc: 0.9510 - 7ms/step\n", + "step 40/586 - loss: 0.3341 - acc: 0.9531 - 7ms/step\n", + "step 50/586 - loss: 0.3775 - acc: 0.9563 - 7ms/step\n", + "step 60/586 - loss: 0.3455 - acc: 0.9578 - 7ms/step\n", + "step 70/586 - loss: 0.3955 - acc: 0.9563 - 7ms/step\n", + "step 80/586 - loss: 0.3743 - acc: 0.9586 - 7ms/step\n", + "step 90/586 - loss: 0.3200 - acc: 0.9587 - 7ms/step\n", + "step 100/586 - loss: 0.3480 - acc: 0.9578 - 7ms/step\n", + "step 110/586 - loss: 0.3540 - acc: 0.9594 - 7ms/step\n", + "step 120/586 - loss: 0.3137 - acc: 0.9609 - 7ms/step\n", + "step 130/586 - loss: 0.3789 - acc: 0.9606 - 7ms/step\n", + "step 140/586 - loss: 0.3223 - acc: 0.9603 - 7ms/step\n", + "step 150/586 - loss: 0.3147 - acc: 0.9608 - 7ms/step\n", + "step 160/586 - loss: 0.3199 - acc: 0.9617 - 7ms/step\n", + "step 170/586 - loss: 0.3418 - acc: 0.9625 - 7ms/step\n", + "step 180/586 - loss: 0.3225 - acc: 0.9634 - 8ms/step\n", + "step 190/586 - loss: 0.3235 - acc: 0.9645 - 8ms/step\n", + "step 200/586 - loss: 0.3151 - acc: 0.9655 - 8ms/step\n", + "step 210/586 - loss: 0.3149 - acc: 0.9658 - 8ms/step\n", + "step 220/586 - loss: 0.3457 - acc: 0.9659 - 8ms/step\n", + "step 230/586 - loss: 0.3459 - acc: 0.9662 - 8ms/step\n", + "step 240/586 - loss: 0.3166 - acc: 0.9664 - 8ms/step\n", + "step 250/586 - loss: 0.3819 - acc: 0.9661 - 8ms/step\n", + "step 260/586 - loss: 0.3473 - acc: 0.9660 - 8ms/step\n", + "step 270/586 - loss: 0.3214 - acc: 0.9661 - 8ms/step\n", + "step 280/586 - loss: 0.4032 - acc: 0.9660 - 8ms/step\n", + "step 290/586 - loss: 0.3486 - acc: 0.9659 - 8ms/step\n", + "step 300/586 - loss: 0.3309 - acc: 0.9663 - 8ms/step\n", + "step 310/586 - loss: 0.3581 - acc: 0.9664 - 7ms/step\n", + "step 320/586 - loss: 0.4081 - acc: 0.9657 - 7ms/step\n", + "step 330/586 - loss: 0.3550 - acc: 0.9653 - 7ms/step\n", + "step 340/586 - loss: 0.3379 - acc: 0.9657 - 7ms/step\n", + "step 350/586 - loss: 0.3423 - acc: 0.9652 - 7ms/step\n", + "step 360/586 - loss: 0.3774 - acc: 0.9649 - 7ms/step\n", + "step 370/586 - loss: 0.3143 - acc: 0.9651 - 7ms/step\n", + "step 380/586 - loss: 0.3399 - acc: 0.9651 - 7ms/step\n", + "step 390/586 - loss: 0.3416 - acc: 0.9655 - 7ms/step\n", + "step 400/586 - loss: 0.3877 - acc: 0.9652 - 7ms/step\n", + "step 410/586 - loss: 0.4009 - 
acc: 0.9649 - 7ms/step\n", + "step 420/586 - loss: 0.3149 - acc: 0.9647 - 7ms/step\n", + "step 430/586 - loss: 0.3817 - acc: 0.9646 - 7ms/step\n", + "step 440/586 - loss: 0.3468 - acc: 0.9649 - 7ms/step\n", + "step 450/586 - loss: 0.3474 - acc: 0.9650 - 7ms/step\n", + "step 460/586 - loss: 0.3547 - acc: 0.9649 - 7ms/step\n", + "step 470/586 - loss: 0.3495 - acc: 0.9651 - 7ms/step\n", + "step 480/586 - loss: 0.3674 - acc: 0.9647 - 7ms/step\n", + "step 490/586 - loss: 0.3634 - acc: 0.9647 - 7ms/step\n", + "step 500/586 - loss: 0.3542 - acc: 0.9647 - 7ms/step\n", + "step 510/586 - loss: 0.3150 - acc: 0.9650 - 7ms/step\n", + "step 520/586 - loss: 0.3141 - acc: 0.9652 - 7ms/step\n", + "step 530/586 - loss: 0.3235 - acc: 0.9652 - 7ms/step\n", + "step 540/586 - loss: 0.3867 - acc: 0.9653 - 7ms/step\n", + "step 550/586 - loss: 0.3493 - acc: 0.9655 - 7ms/step\n", + "step 560/586 - loss: 0.4191 - acc: 0.9656 - 7ms/step\n", + "step 570/586 - loss: 0.3169 - acc: 0.9650 - 7ms/step\n", + "step 580/586 - loss: 0.3171 - acc: 0.9649 - 7ms/step\n", + "step 586/586 - loss: 0.3298 - acc: 0.9648 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4102 - acc: 0.8781 - 3ms/step\n", + "step 20/196 - loss: 0.3831 - acc: 0.8906 - 3ms/step\n", + "step 30/196 - loss: 0.5540 - acc: 0.8802 - 3ms/step\n", + "step 40/196 - loss: 0.5060 - acc: 0.8727 - 2ms/step\n", + "step 50/196 - loss: 0.4351 - acc: 0.8750 - 2ms/step\n", + "step 60/196 - loss: 0.3830 - acc: 0.8698 - 2ms/step\n", + "step 70/196 - loss: 0.4603 - acc: 0.8723 - 2ms/step\n", + "step 80/196 - loss: 0.4188 - acc: 0.8703 - 2ms/step\n", + "step 90/196 - loss: 0.5685 - acc: 0.8691 - 2ms/step\n", + "step 100/196 - loss: 0.4086 - acc: 0.8719 - 2ms/step\n", + "step 110/196 - loss: 0.4628 - acc: 0.8722 - 3ms/step\n", + "step 120/196 - loss: 0.3791 - acc: 0.8674 - 3ms/step\n", + "step 130/196 - loss: 0.4087 - acc: 0.8673 - 3ms/step\n", + "step 140/196 - loss: 0.4109 - acc: 0.8688 - 3ms/step\n", + "step 150/196 - loss: 0.4144 - acc: 0.8688 - 3ms/step\n", + "step 160/196 - loss: 0.5291 - acc: 0.8666 - 3ms/step\n", + "step 170/196 - loss: 0.4071 - acc: 0.8678 - 3ms/step\n", + "step 180/196 - loss: 0.3402 - acc: 0.8703 - 3ms/step\n", + "step 190/196 - loss: 0.4466 - acc: 0.8707 - 3ms/step\n", + "step 196/196 - loss: 0.4286 - acc: 0.8712 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 8/10\n", + "step 10/586 - loss: 0.3689 - acc: 0.9531 - 7ms/step\n", + "step 20/586 - loss: 0.3800 - acc: 0.9531 - 7ms/step\n", + "step 30/586 - loss: 0.3609 - acc: 0.9583 - 7ms/step\n", + "step 40/586 - loss: 0.3177 - acc: 0.9586 - 7ms/step\n", + "step 50/586 - loss: 0.4016 - acc: 0.9594 - 7ms/step\n", + "step 60/586 - loss: 0.3537 - acc: 0.9609 - 8ms/step\n", + "step 70/586 - loss: 0.3203 - acc: 0.9616 - 7ms/step\n", + "step 80/586 - loss: 0.4411 - acc: 0.9609 - 7ms/step\n", + "step 90/586 - loss: 0.3150 - acc: 0.9639 - 7ms/step\n", + "step 100/586 - loss: 0.3471 - acc: 0.9641 - 7ms/step\n", + "step 110/586 - loss: 0.3501 - acc: 0.9642 - 7ms/step\n", + "step 120/586 - loss: 0.3169 - acc: 0.9648 - 7ms/step\n", + "step 130/586 - loss: 0.3198 - acc: 0.9656 - 7ms/step\n", + "step 140/586 - loss: 0.3776 - acc: 0.9658 - 7ms/step\n", + "step 150/586 - loss: 0.3153 - acc: 0.9660 - 7ms/step\n", + "step 160/586 - loss: 0.3440 - acc: 0.9664 - 7ms/step\n", + "step 170/586 - loss: 0.3279 - acc: 0.9675 - 7ms/step\n", + "step 180/586 - loss: 0.3984 - acc: 0.9674 - 7ms/step\n", + "step 190/586 - loss: 0.3500 - acc: 0.9666 - 7ms/step\n", + "step 200/586 - loss: 0.3219 - acc: 0.9675 - 7ms/step\n", 
+ "step 210/586 - loss: 0.3145 - acc: 0.9683 - 7ms/step\n", + "step 220/586 - loss: 0.3149 - acc: 0.9688 - 7ms/step\n", + "step 230/586 - loss: 0.3151 - acc: 0.9688 - 7ms/step\n", + "step 240/586 - loss: 0.3338 - acc: 0.9689 - 7ms/step\n", + "step 250/586 - loss: 0.3511 - acc: 0.9692 - 7ms/step\n", + "step 260/586 - loss: 0.3172 - acc: 0.9692 - 7ms/step\n", + "step 270/586 - loss: 0.3563 - acc: 0.9691 - 7ms/step\n", + "step 280/586 - loss: 0.3349 - acc: 0.9693 - 7ms/step\n", + "step 290/586 - loss: 0.3152 - acc: 0.9692 - 7ms/step\n", + "step 300/586 - loss: 0.3159 - acc: 0.9694 - 7ms/step\n", + "step 310/586 - loss: 0.3461 - acc: 0.9695 - 7ms/step\n", + "step 320/586 - loss: 0.3141 - acc: 0.9696 - 7ms/step\n", + "step 330/586 - loss: 0.3141 - acc: 0.9694 - 7ms/step\n", + "step 340/586 - loss: 0.3457 - acc: 0.9698 - 7ms/step\n", + "step 350/586 - loss: 0.4385 - acc: 0.9690 - 7ms/step\n", + "step 360/586 - loss: 0.3157 - acc: 0.9684 - 7ms/step\n", + "step 370/586 - loss: 0.3917 - acc: 0.9679 - 7ms/step\n", + "step 380/586 - loss: 0.3459 - acc: 0.9679 - 7ms/step\n", + "step 390/586 - loss: 0.3501 - acc: 0.9683 - 7ms/step\n", + "step 400/586 - loss: 0.3754 - acc: 0.9684 - 7ms/step\n", + "step 410/586 - loss: 0.3550 - acc: 0.9685 - 7ms/step\n", + "step 420/586 - loss: 0.3843 - acc: 0.9689 - 7ms/step\n", + "step 430/586 - loss: 0.3151 - acc: 0.9691 - 7ms/step\n", + "step 440/586 - loss: 0.3450 - acc: 0.9692 - 7ms/step\n", + "step 450/586 - loss: 0.3548 - acc: 0.9696 - 7ms/step\n", + "step 460/586 - loss: 0.3773 - acc: 0.9695 - 7ms/step\n", + "step 470/586 - loss: 0.3140 - acc: 0.9694 - 7ms/step\n", + "step 480/586 - loss: 0.3399 - acc: 0.9692 - 7ms/step\n", + "step 490/586 - loss: 0.3873 - acc: 0.9690 - 7ms/step\n", + "step 500/586 - loss: 0.3411 - acc: 0.9691 - 7ms/step\n", + "step 510/586 - loss: 0.3454 - acc: 0.9691 - 7ms/step\n", + "step 520/586 - loss: 0.3436 - acc: 0.9690 - 7ms/step\n", + "step 530/586 - loss: 0.3766 - acc: 0.9688 - 7ms/step\n", + "step 540/586 - loss: 0.3471 - acc: 0.9685 - 7ms/step\n", + "step 550/586 - loss: 0.3464 - acc: 0.9688 - 7ms/step\n", + "step 560/586 - loss: 0.3193 - acc: 0.9689 - 7ms/step\n", + "step 570/586 - loss: 0.3206 - acc: 0.9688 - 7ms/step\n", + "step 580/586 - loss: 0.4221 - acc: 0.9688 - 7ms/step\n", + "step 586/586 - loss: 0.4137 - acc: 0.9687 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4021 - acc: 0.8750 - 3ms/step\n", + "step 20/196 - loss: 0.3832 - acc: 0.8938 - 3ms/step\n", + "step 30/196 - loss: 0.5667 - acc: 0.8802 - 3ms/step\n", + "step 40/196 - loss: 0.4844 - acc: 0.8727 - 3ms/step\n", + "step 50/196 - loss: 0.4485 - acc: 0.8712 - 3ms/step\n", + "step 60/196 - loss: 0.3879 - acc: 0.8661 - 3ms/step\n", + "step 70/196 - loss: 0.4494 - acc: 0.8719 - 3ms/step\n", + "step 80/196 - loss: 0.4179 - acc: 0.8707 - 3ms/step\n", + "step 90/196 - loss: 0.5252 - acc: 0.8712 - 3ms/step\n", + "step 100/196 - loss: 0.3908 - acc: 0.8728 - 3ms/step\n", + "step 110/196 - loss: 0.4374 - acc: 0.8730 - 2ms/step\n", + "step 120/196 - loss: 0.3779 - acc: 0.8685 - 2ms/step\n", + "step 130/196 - loss: 0.4083 - acc: 0.8680 - 2ms/step\n", + "step 140/196 - loss: 0.4196 - acc: 0.8688 - 2ms/step\n", + "step 150/196 - loss: 0.3966 - acc: 0.8683 - 2ms/step\n", + "step 160/196 - loss: 0.5057 - acc: 0.8670 - 2ms/step\n", + "step 170/196 - loss: 0.3764 - acc: 0.8676 - 2ms/step\n", + "step 180/196 - loss: 0.3452 - acc: 0.8693 - 2ms/step\n", + "step 190/196 - loss: 0.4252 - acc: 0.8689 - 2ms/step\n", + "step 196/196 - loss: 0.4172 - acc: 0.8696 - 2ms/step\n", + 
"Eval samples: 6250\n", + "Epoch 9/10\n", + "step 10/586 - loss: 0.3192 - acc: 0.9875 - 7ms/step\n", + "step 20/586 - loss: 0.3457 - acc: 0.9844 - 8ms/step\n", + "step 30/586 - loss: 0.3765 - acc: 0.9771 - 7ms/step\n", + "step 40/586 - loss: 0.3740 - acc: 0.9680 - 7ms/step\n", + "step 50/586 - loss: 0.3542 - acc: 0.9656 - 7ms/step\n", + "step 60/586 - loss: 0.3400 - acc: 0.9625 - 7ms/step\n", + "step 70/586 - loss: 0.3535 - acc: 0.9625 - 7ms/step\n", + "step 80/586 - loss: 0.3456 - acc: 0.9645 - 7ms/step\n", + "step 90/586 - loss: 0.3141 - acc: 0.9663 - 7ms/step\n", + "step 100/586 - loss: 0.3465 - acc: 0.9650 - 7ms/step\n", + "step 110/586 - loss: 0.3315 - acc: 0.9645 - 7ms/step\n", + "step 120/586 - loss: 0.3145 - acc: 0.9659 - 7ms/step\n", + "step 130/586 - loss: 0.3475 - acc: 0.9668 - 7ms/step\n", + "step 140/586 - loss: 0.3171 - acc: 0.9683 - 7ms/step\n", + "step 150/586 - loss: 0.3462 - acc: 0.9681 - 7ms/step\n", + "step 160/586 - loss: 0.3492 - acc: 0.9682 - 7ms/step\n", + "step 170/586 - loss: 0.3475 - acc: 0.9689 - 7ms/step\n", + "step 180/586 - loss: 0.3466 - acc: 0.9694 - 7ms/step\n", + "step 190/586 - loss: 0.4103 - acc: 0.9696 - 7ms/step\n", + "step 200/586 - loss: 0.3672 - acc: 0.9700 - 7ms/step\n", + "step 210/586 - loss: 0.4100 - acc: 0.9695 - 7ms/step\n", + "step 220/586 - loss: 0.4084 - acc: 0.9699 - 7ms/step\n", + "step 230/586 - loss: 0.3141 - acc: 0.9707 - 7ms/step\n", + "step 240/586 - loss: 0.3450 - acc: 0.9708 - 7ms/step\n", + "step 250/586 - loss: 0.3462 - acc: 0.9705 - 7ms/step\n", + "step 260/586 - loss: 0.3178 - acc: 0.9706 - 7ms/step\n", + "step 270/586 - loss: 0.3451 - acc: 0.9703 - 7ms/step\n", + "step 280/586 - loss: 0.3493 - acc: 0.9705 - 7ms/step\n", + "step 290/586 - loss: 0.3174 - acc: 0.9711 - 7ms/step\n", + "step 300/586 - loss: 0.3171 - acc: 0.9716 - 7ms/step\n", + "step 310/586 - loss: 0.3478 - acc: 0.9720 - 7ms/step\n", + "step 320/586 - loss: 0.3220 - acc: 0.9723 - 7ms/step\n", + "step 330/586 - loss: 0.3139 - acc: 0.9724 - 7ms/step\n", + "step 340/586 - loss: 0.3137 - acc: 0.9730 - 7ms/step\n", + "step 350/586 - loss: 0.4082 - acc: 0.9728 - 7ms/step\n", + "step 360/586 - loss: 0.3447 - acc: 0.9727 - 7ms/step\n", + "step 370/586 - loss: 0.3136 - acc: 0.9728 - 7ms/step\n", + "step 380/586 - loss: 0.3284 - acc: 0.9728 - 7ms/step\n", + "step 390/586 - loss: 0.4076 - acc: 0.9726 - 7ms/step\n", + "step 400/586 - loss: 0.3646 - acc: 0.9726 - 7ms/step\n", + "step 410/586 - loss: 0.3137 - acc: 0.9723 - 7ms/step\n", + "step 420/586 - loss: 0.3452 - acc: 0.9724 - 7ms/step\n", + "step 430/586 - loss: 0.3210 - acc: 0.9720 - 7ms/step\n", + "step 440/586 - loss: 0.3764 - acc: 0.9719 - 7ms/step\n", + "step 450/586 - loss: 0.3449 - acc: 0.9721 - 7ms/step\n", + "step 460/586 - loss: 0.3808 - acc: 0.9724 - 7ms/step\n", + "step 470/586 - loss: 0.3767 - acc: 0.9723 - 7ms/step\n", + "step 480/586 - loss: 0.3582 - acc: 0.9720 - 7ms/step\n", + "step 490/586 - loss: 0.4074 - acc: 0.9721 - 7ms/step\n", + "step 500/586 - loss: 0.3281 - acc: 0.9724 - 7ms/step\n", + "step 510/586 - loss: 0.3197 - acc: 0.9725 - 7ms/step\n", + "step 520/586 - loss: 0.3449 - acc: 0.9725 - 7ms/step\n", + "step 530/586 - loss: 0.3772 - acc: 0.9723 - 7ms/step\n", + "step 540/586 - loss: 0.3460 - acc: 0.9723 - 7ms/step\n", + "step 550/586 - loss: 0.3758 - acc: 0.9719 - 7ms/step\n", + "step 560/586 - loss: 0.3837 - acc: 0.9720 - 7ms/step\n", + "step 570/586 - loss: 0.3185 - acc: 0.9718 - 7ms/step\n", + "step 580/586 - loss: 0.3173 - acc: 0.9720 - 7ms/step\n", + "step 586/586 - loss: 0.3142 - acc: 
0.9721 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4118 - acc: 0.8562 - 3ms/step\n", + "step 20/196 - loss: 0.4136 - acc: 0.8688 - 3ms/step\n", + "step 30/196 - loss: 0.5431 - acc: 0.8729 - 3ms/step\n", + "step 40/196 - loss: 0.4878 - acc: 0.8641 - 2ms/step\n", + "step 50/196 - loss: 0.4139 - acc: 0.8675 - 2ms/step\n", + "step 60/196 - loss: 0.3872 - acc: 0.8646 - 2ms/step\n", + "step 70/196 - loss: 0.4269 - acc: 0.8692 - 2ms/step\n", + "step 80/196 - loss: 0.4665 - acc: 0.8668 - 2ms/step\n", + "step 90/196 - loss: 0.5964 - acc: 0.8670 - 2ms/step\n", + "step 100/196 - loss: 0.4225 - acc: 0.8709 - 2ms/step\n", + "step 110/196 - loss: 0.4720 - acc: 0.8696 - 2ms/step\n", + "step 120/196 - loss: 0.3814 - acc: 0.8635 - 2ms/step\n", + "step 130/196 - loss: 0.4242 - acc: 0.8635 - 3ms/step\n", + "step 140/196 - loss: 0.3902 - acc: 0.8661 - 3ms/step\n", + "step 150/196 - loss: 0.4303 - acc: 0.8648 - 3ms/step\n", + "step 160/196 - loss: 0.5004 - acc: 0.8633 - 3ms/step\n", + "step 170/196 - loss: 0.4446 - acc: 0.8632 - 3ms/step\n", + "step 180/196 - loss: 0.3417 - acc: 0.8656 - 3ms/step\n", + "step 190/196 - loss: 0.4667 - acc: 0.8660 - 3ms/step\n", + "step 196/196 - loss: 0.4134 - acc: 0.8664 - 3ms/step\n", + "Eval samples: 6250\n", + "Epoch 10/10\n", + "step 10/586 - loss: 0.3144 - acc: 0.9781 - 7ms/step\n", + "step 20/586 - loss: 0.3819 - acc: 0.9719 - 7ms/step\n", + "step 30/586 - loss: 0.3147 - acc: 0.9698 - 7ms/step\n", + "step 40/586 - loss: 0.3139 - acc: 0.9727 - 7ms/step\n", + "step 50/586 - loss: 0.3788 - acc: 0.9738 - 7ms/step\n", + "step 60/586 - loss: 0.3472 - acc: 0.9724 - 7ms/step\n", + "step 70/586 - loss: 0.3139 - acc: 0.9714 - 7ms/step\n", + "step 80/586 - loss: 0.3453 - acc: 0.9727 - 7ms/step\n", + "step 90/586 - loss: 0.3769 - acc: 0.9729 - 7ms/step\n", + "step 100/586 - loss: 0.3460 - acc: 0.9734 - 7ms/step\n", + "step 110/586 - loss: 0.3137 - acc: 0.9727 - 7ms/step\n", + "step 120/586 - loss: 0.3137 - acc: 0.9721 - 7ms/step\n", + "step 130/586 - loss: 0.3458 - acc: 0.9724 - 7ms/step\n", + "step 140/586 - loss: 0.3453 - acc: 0.9732 - 7ms/step\n", + "step 150/586 - loss: 0.3457 - acc: 0.9729 - 7ms/step\n", + "step 160/586 - loss: 0.3145 - acc: 0.9740 - 7ms/step\n", + "step 170/586 - loss: 0.3614 - acc: 0.9732 - 7ms/step\n", + "step 180/586 - loss: 0.3550 - acc: 0.9731 - 7ms/step\n", + "step 190/586 - loss: 0.3135 - acc: 0.9735 - 7ms/step\n", + "step 200/586 - loss: 0.3638 - acc: 0.9739 - 7ms/step\n", + "step 210/586 - loss: 0.3447 - acc: 0.9737 - 7ms/step\n", + "step 220/586 - loss: 0.3136 - acc: 0.9734 - 7ms/step\n", + "step 230/586 - loss: 0.3480 - acc: 0.9735 - 7ms/step\n", + "step 240/586 - loss: 0.3144 - acc: 0.9734 - 7ms/step\n", + "step 250/586 - loss: 0.3147 - acc: 0.9740 - 7ms/step\n", + "step 260/586 - loss: 0.3135 - acc: 0.9742 - 7ms/step\n", + "step 270/586 - loss: 0.3768 - acc: 0.9748 - 7ms/step\n", + "step 280/586 - loss: 0.3455 - acc: 0.9749 - 7ms/step\n", + "step 290/586 - loss: 0.3147 - acc: 0.9748 - 7ms/step\n", + "step 300/586 - loss: 0.3765 - acc: 0.9745 - 7ms/step\n", + "step 310/586 - loss: 0.3761 - acc: 0.9742 - 7ms/step\n", + "step 320/586 - loss: 0.3487 - acc: 0.9739 - 7ms/step\n", + "step 330/586 - loss: 0.3621 - acc: 0.9739 - 7ms/step\n", + "step 340/586 - loss: 0.3145 - acc: 0.9738 - 7ms/step\n", + "step 350/586 - loss: 0.3135 - acc: 0.9738 - 7ms/step\n", + "step 360/586 - loss: 0.3454 - acc: 0.9740 - 7ms/step\n", + "step 370/586 - loss: 0.3145 - acc: 0.9744 - 7ms/step\n", + "step 380/586 - loss: 0.3454 - acc: 0.9745 - 7ms/step\n", + 
"step 390/586 - loss: 0.3462 - acc: 0.9747 - 7ms/step\n", + "step 400/586 - loss: 0.3152 - acc: 0.9750 - 7ms/step\n", + "step 410/586 - loss: 0.3473 - acc: 0.9753 - 7ms/step\n", + "step 420/586 - loss: 0.3449 - acc: 0.9754 - 7ms/step\n", + "step 430/586 - loss: 0.3154 - acc: 0.9757 - 7ms/step\n", + "step 440/586 - loss: 0.3457 - acc: 0.9759 - 7ms/step\n", + "step 450/586 - loss: 0.3457 - acc: 0.9757 - 7ms/step\n", + "step 460/586 - loss: 0.3447 - acc: 0.9757 - 7ms/step\n", + "step 470/586 - loss: 0.3137 - acc: 0.9757 - 7ms/step\n", + "step 480/586 - loss: 0.3139 - acc: 0.9759 - 7ms/step\n", + "step 490/586 - loss: 0.3473 - acc: 0.9760 - 7ms/step\n", + "step 500/586 - loss: 0.3155 - acc: 0.9759 - 7ms/step\n", + "step 510/586 - loss: 0.3760 - acc: 0.9757 - 7ms/step\n", + "step 520/586 - loss: 0.3452 - acc: 0.9755 - 7ms/step\n", + "step 530/586 - loss: 0.3139 - acc: 0.9756 - 7ms/step\n", + "step 540/586 - loss: 0.3139 - acc: 0.9756 - 7ms/step\n", + "step 550/586 - loss: 0.3143 - acc: 0.9757 - 7ms/step\n", + "step 560/586 - loss: 0.3144 - acc: 0.9759 - 7ms/step\n", + "step 570/586 - loss: 0.3450 - acc: 0.9759 - 7ms/step\n", + "step 580/586 - loss: 0.3245 - acc: 0.9758 - 7ms/step\n", + "step 586/586 - loss: 0.3829 - acc: 0.9756 - 7ms/step\n", + "Eval begin...\n", + "step 10/196 - loss: 0.4100 - acc: 0.8531 - 5ms/step\n", + "step 20/196 - loss: 0.4061 - acc: 0.8703 - 4ms/step\n", + "step 30/196 - loss: 0.5566 - acc: 0.8719 - 3ms/step\n", + "step 40/196 - loss: 0.4805 - acc: 0.8656 - 3ms/step\n", + "step 50/196 - loss: 0.4235 - acc: 0.8662 - 3ms/step\n", + "step 60/196 - loss: 0.4023 - acc: 0.8620 - 3ms/step\n", + "step 70/196 - loss: 0.4327 - acc: 0.8656 - 3ms/step\n", + "step 80/196 - loss: 0.4856 - acc: 0.8625 - 3ms/step\n", + "step 90/196 - loss: 0.5713 - acc: 0.8639 - 3ms/step\n", + "step 100/196 - loss: 0.3963 - acc: 0.8678 - 3ms/step\n", + "step 110/196 - loss: 0.4678 - acc: 0.8676 - 3ms/step\n", + "step 120/196 - loss: 0.4025 - acc: 0.8625 - 3ms/step\n", + "step 130/196 - loss: 0.4336 - acc: 0.8627 - 3ms/step\n", + "step 140/196 - loss: 0.3946 - acc: 0.8652 - 3ms/step\n", + "step 150/196 - loss: 0.4038 - acc: 0.8646 - 3ms/step\n", + "step 160/196 - loss: 0.5087 - acc: 0.8633 - 3ms/step\n", + "step 170/196 - loss: 0.4656 - acc: 0.8638 - 3ms/step\n", + "step 180/196 - loss: 0.3433 - acc: 0.8660 - 3ms/step\n", + "step 190/196 - loss: 0.4656 - acc: 0.8663 - 3ms/step\n", + "step 196/196 - loss: 0.4132 - acc: 0.8672 - 3ms/step\n", + "Eval samples: 6250\n" + ] + } + ], + "source": [ + "class DataReader(Dataset):\r\n", + " def __init__(self, input, label, length):\r\n", + " self.data = list(vectorizer(input, label, length=length))\r\n", + "\r\n", + " def __getitem__(self, idx):\r\n", + " return self.data[idx]\r\n", + "\r\n", + " def __len__(self):\r\n", + " return len(self.data)\r\n", + "\r\n", + "\r\n", + "# 指定训练设备\r\n", + "device = pd.set_device('gpu') # 可选:cpu\r\n", + "\r\n", + "# 开启动态图模式\r\n", + "pd.disable_static(device)\r\n", + "\r\n", + "# 定义输入格式\r\n", + "input_form = pd.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", + "label_form = pd.static.InputSpec(shape=[None, 1], dtype='int64', name='label')\r\n", + "\r\n", + "model = pd.Model(sim_model, input_form, label_form)\r\n", + "model.prepare(optimizer=pd.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),\r\n", + " loss=pd.nn.loss.CrossEntropyLoss(),\r\n", + " metrics=pd.metric.Accuracy())\r\n", + "\r\n", + "# 分割训练集和验证集\r\n", + "eval_length = int(len(train_x) * 1/4)\r\n", + 
"model.fit(train_data=DataReader(train_x[:-eval_length], train_y[:-eval_length], length),\r\n", + " eval_data=DataReader(train_x[-eval_length:], train_y[-eval_length:], length),\r\n", + " batch_size=32, epochs=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 评估效果并用模型预测" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval begin...\n", + "step 10/782 - loss: 0.4515 - acc: 0.8531 - 3ms/step\n", + "step 20/782 - loss: 0.5053 - acc: 0.8656 - 3ms/step\n", + "step 30/782 - loss: 0.4896 - acc: 0.8406 - 3ms/step\n", + "step 40/782 - loss: 0.3849 - acc: 0.8469 - 3ms/step\n", + "step 50/782 - loss: 0.5705 - acc: 0.8331 - 3ms/step\n", + "step 60/782 - loss: 0.3480 - acc: 0.8370 - 3ms/step\n", + "step 70/782 - loss: 0.3403 - acc: 0.8460 - 3ms/step\n", + "step 80/782 - loss: 0.3370 - acc: 0.8473 - 3ms/step\n", + "step 90/782 - loss: 0.5180 - acc: 0.8462 - 3ms/step\n", + "step 100/782 - loss: 0.4266 - acc: 0.8481 - 3ms/step\n", + "step 110/782 - loss: 0.4605 - acc: 0.8486 - 3ms/step\n", + "step 120/782 - loss: 0.3836 - acc: 0.8477 - 3ms/step\n", + "step 130/782 - loss: 0.4657 - acc: 0.8474 - 3ms/step\n", + "step 140/782 - loss: 0.4203 - acc: 0.8462 - 3ms/step\n", + "step 150/782 - loss: 0.4735 - acc: 0.8408 - 3ms/step\n", + "step 160/782 - loss: 0.4959 - acc: 0.8412 - 3ms/step\n", + "step 170/782 - loss: 0.3490 - acc: 0.8419 - 3ms/step\n", + "step 180/782 - loss: 0.6037 - acc: 0.8415 - 3ms/step\n", + "step 190/782 - loss: 0.4110 - acc: 0.8416 - 3ms/step\n", + "step 200/782 - loss: 0.5318 - acc: 0.8430 - 3ms/step\n", + "step 210/782 - loss: 0.4332 - acc: 0.8449 - 3ms/step\n", + "step 220/782 - loss: 0.6212 - acc: 0.8447 - 3ms/step\n", + "step 230/782 - loss: 0.4884 - acc: 0.8443 - 3ms/step\n", + "step 240/782 - loss: 0.3646 - acc: 0.8434 - 3ms/step\n", + "step 250/782 - loss: 0.4735 - acc: 0.8446 - 3ms/step\n", + "step 260/782 - loss: 0.4272 - acc: 0.8460 - 3ms/step\n", + "step 270/782 - loss: 0.5258 - acc: 0.8453 - 3ms/step\n", + "step 280/782 - loss: 0.4614 - acc: 0.8449 - 3ms/step\n", + "step 290/782 - loss: 0.4773 - acc: 0.8454 - 3ms/step\n", + "step 300/782 - loss: 0.5187 - acc: 0.8441 - 3ms/step\n", + "step 310/782 - loss: 0.4952 - acc: 0.8431 - 3ms/step\n", + "step 320/782 - loss: 0.3959 - acc: 0.8435 - 3ms/step\n", + "step 330/782 - loss: 0.4840 - acc: 0.8437 - 3ms/step\n", + "step 340/782 - loss: 0.3650 - acc: 0.8441 - 3ms/step\n", + "step 350/782 - loss: 0.4842 - acc: 0.8450 - 3ms/step\n", + "step 360/782 - loss: 0.4866 - acc: 0.8444 - 3ms/step\n", + "step 370/782 - loss: 0.4882 - acc: 0.8454 - 3ms/step\n", + "step 380/782 - loss: 0.4428 - acc: 0.8434 - 3ms/step\n", + "step 390/782 - loss: 0.4084 - acc: 0.8430 - 3ms/step\n", + "step 400/782 - loss: 0.4584 - acc: 0.8433 - 3ms/step\n", + "step 410/782 - loss: 0.5239 - acc: 0.8442 - 3ms/step\n", + "step 420/782 - loss: 0.4221 - acc: 0.8453 - 3ms/step\n", + "step 430/782 - loss: 0.3200 - acc: 0.8466 - 3ms/step\n", + "step 440/782 - loss: 0.3503 - acc: 0.8479 - 3ms/step\n", + "step 450/782 - loss: 0.4750 - acc: 0.8488 - 3ms/step\n", + "step 460/782 - loss: 0.4753 - acc: 0.8505 - 3ms/step\n", + "step 470/782 - loss: 0.5096 - acc: 0.8504 - 3ms/step\n", + "step 480/782 - loss: 0.4834 - acc: 0.8513 - 3ms/step\n", + "step 490/782 - loss: 0.3860 - acc: 0.8527 - 3ms/step\n", + "step 500/782 - loss: 0.5332 - acc: 0.8533 - 3ms/step\n", + "step 510/782 - loss: 0.4014 
- acc: 0.8533 - 3ms/step\n", + "step 520/782 - loss: 0.4066 - acc: 0.8547 - 3ms/step\n", + "step 530/782 - loss: 0.4554 - acc: 0.8557 - 3ms/step\n", + "step 540/782 - loss: 0.5141 - acc: 0.8560 - 3ms/step\n", + "step 550/782 - loss: 0.4621 - acc: 0.8568 - 3ms/step\n", + "step 560/782 - loss: 0.4383 - acc: 0.8576 - 3ms/step\n", + "step 570/782 - loss: 0.3677 - acc: 0.8584 - 3ms/step\n", + "step 580/782 - loss: 0.5716 - acc: 0.8588 - 3ms/step\n", + "step 590/782 - loss: 0.4613 - acc: 0.8596 - 3ms/step\n", + "step 600/782 - loss: 0.4694 - acc: 0.8602 - 3ms/step\n", + "step 610/782 - loss: 0.3561 - acc: 0.8609 - 3ms/step\n", + "step 620/782 - loss: 0.4349 - acc: 0.8608 - 3ms/step\n", + "step 630/782 - loss: 0.4117 - acc: 0.8618 - 3ms/step\n", + "step 640/782 - loss: 0.3703 - acc: 0.8621 - 3ms/step\n", + "step 650/782 - loss: 0.3898 - acc: 0.8623 - 3ms/step\n", + "step 660/782 - loss: 0.4767 - acc: 0.8625 - 3ms/step\n", + "step 670/782 - loss: 0.4580 - acc: 0.8626 - 3ms/step\n", + "step 680/782 - loss: 0.4189 - acc: 0.8622 - 3ms/step\n", + "step 690/782 - loss: 0.4569 - acc: 0.8622 - 3ms/step\n", + "step 700/782 - loss: 0.3807 - acc: 0.8627 - 3ms/step\n", + "step 710/782 - loss: 0.4707 - acc: 0.8632 - 3ms/step\n", + "step 720/782 - loss: 0.3709 - acc: 0.8633 - 3ms/step\n", + "step 730/782 - loss: 0.4519 - acc: 0.8643 - 3ms/step\n", + "step 740/782 - loss: 0.4227 - acc: 0.8651 - 3ms/step\n", + "step 750/782 - loss: 0.4386 - acc: 0.8651 - 3ms/step\n", + "step 760/782 - loss: 0.3844 - acc: 0.8653 - 3ms/step\n", + "step 770/782 - loss: 0.3988 - acc: 0.8657 - 3ms/step\n", + "step 780/782 - loss: 0.3374 - acc: 0.8662 - 3ms/step\n", + "step 782/782 - loss: 0.4368 - acc: 0.8664 - 3ms/step\n", + "Eval samples: 25000\n", + "Predict begin...\n", + "step 10/10 [==============================] - 2ms/step \n", + "Predict samples: 10\n", + "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:1, 实际标签是:1\n" ] } ], "source": [ "# 评估\r\n", "model.evaluate(eval_data=DataReader(test_x, test_y, length), batch_size=32)\r\n", "\r\n", "# 预测\r\n", "true_y = test_y[100:105] + test_y[-110:-105]\r\n", "pred_y = model.predict(DataReader(test_x[100:105] + test_x[-110:-105], None, length), batch_size=1)\r\n", "\r\n", "for index, y in enumerate(pred_y[0]):\r\n", " print(\"预测的标签是:%d, 实际标签是:%d\" % (np.argmax(y), true_y[index]))" ] } ], "metadata": { "kernelspec": { "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", "language": "python", "name": "py35-paddle1.2.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 +} From 32778b065b74ff039dfe2a6626b652ef7fd07244 Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Fri, 27 Nov 2020 16:35:11 +0800 Subject: [PATCH 02/14] =?UTF-8?q?=E9=A3=9E=E6=A1=A82.0=E5=AE=9E=E4=BE=8B?= =?UTF-8?q?=E6=95=99=E7=A8=8B=E2=80=94=E2=80=94=E4=BD=BF=E7=94=A8=E9=A2=84?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E8=AF=8D=E5=90=91=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed the redundant code at lines 1312-1315; changed verbose to 1 for training and evaluation and regenerated the outputs. --- .../pretrained_word_embeddings.ipynb | 963 
+----------------- 1 file changed, 53 insertions(+), 910 deletions(-) diff --git a/paddle2.0_docs/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings.ipynb index 8242dc19..461892aa 100644 --- a/paddle2.0_docs/pretrained_word_embeddings.ipynb +++ b/paddle2.0_docs/pretrained_word_embeddings.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "collapsed": false }, @@ -66,11 +66,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Cache file /home/aistudio/.cache/paddle/dataset/imdb/imdb%2FaclImdb_v1.tar.gz not found, downloading https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz \n", + "Begin to download\n", + "\n", + "Download finished\n" + ] + } + ], "source": [ "imdb_train = pt.Imdb(mode='train', cutoff=150)\r\n", "imdb_test = pt.Imdb(mode='test', cutoff=150)" @@ -88,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "collapsed": false }, @@ -125,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -152,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "collapsed": false }, @@ -180,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "metadata": { "collapsed": false }, @@ -205,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false }, @@ -237,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false }, @@ -273,7 +284,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "collapsed": false }, @@ -300,15 +311,15 @@ "collapsed": false }, "source": [ - "观察词表的后5个单词,我们发现,最后一个词是\"<unk>\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", - "的二进制编码形式,使用中注意使用b'the'.decode()来进行转换('$$'并没有进行二进制编码,注意区分)。\n", + "观察词表的后5个单词,我们发现,最后一个词是\"\\<unk\\>\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", + "的二进制编码形式,使用中注意使用b'the'.decode()来进行转换('\\<unk\\>'并没有进行二进制编码,注意区分)。\n", "接下来,我们给词表中的每个词匹配对应的词向量。预训练词向量可能没有覆盖数据集词表中的所有词,对于没有的词,我们设该词的词\n", "向量为零向量。" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -338,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false }, @@ -367,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "collapsed": false }, @@ -405,7 +416,7 @@ "{'total_params': 529692, 'trainable_params': 529692}" ] }, - "execution_count": null, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -450,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": { "collapsed": false }, @@ -460,838 +471,54 @@ "output_type": "stream", "text": [ "Epoch 1/10\n", - "step 10/586 - loss: 0.8757 - acc: 0.4813 - 18ms/step\n", - "step 20/586 - loss: 0.8331 - acc: 0.4828 - 13ms/step\n", - "step 30/586 - loss: 0.6944 - acc: 0.5042 - 11ms/step\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 
'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", - " return (isinstance(seq, collections.Sequence) and\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step 40/586 - loss: 0.7220 - acc: 0.5070 - 10ms/step\n", - "step 50/586 - loss: 0.6808 - acc: 0.4981 - 9ms/step\n", - "step 60/586 - loss: 0.7056 - acc: 0.5010 - 9ms/step\n", - "step 70/586 - loss: 0.6920 - acc: 0.5004 - 8ms/step\n", - "step 80/586 - loss: 0.6837 - acc: 0.5035 - 8ms/step\n", - "step 90/586 - loss: 0.6995 - acc: 0.4997 - 8ms/step\n", - "step 100/586 - loss: 0.6805 - acc: 0.5056 - 8ms/step\n", - "step 110/586 - loss: 0.6981 - acc: 0.5051 - 8ms/step\n", - "step 120/586 - loss: 0.7033 - acc: 0.5070 - 8ms/step\n", - "step 130/586 - loss: 0.7437 - acc: 0.5108 - 8ms/step\n", - "step 140/586 - loss: 0.6721 - acc: 0.5109 - 8ms/step\n", - "step 150/586 - loss: 0.6856 - acc: 0.5083 - 7ms/step\n", - "step 160/586 - loss: 0.6862 - acc: 0.5119 - 7ms/step\n", - "step 170/586 - loss: 0.6881 - acc: 0.5132 - 7ms/step\n", - "step 180/586 - loss: 0.6655 - acc: 0.5141 - 7ms/step\n", - "step 190/586 - loss: 0.6620 - acc: 0.5155 - 7ms/step\n", - "step 200/586 - loss: 0.6299 - acc: 0.5219 - 7ms/step\n", - "step 210/586 - loss: 0.7355 - acc: 0.5228 - 7ms/step\n", - "step 220/586 - loss: 0.6562 - acc: 0.5267 - 7ms/step\n", - "step 230/586 - loss: 0.6495 - acc: 0.5318 - 7ms/step\n", - "step 240/586 - loss: 0.6333 - acc: 0.5375 - 7ms/step\n", - "step 250/586 - loss: 0.6000 - acc: 0.5427 - 8ms/step\n", - "step 260/586 - loss: 0.5711 - acc: 0.5496 - 8ms/step\n", - "step 270/586 - loss: 0.5693 - acc: 0.5546 - 8ms/step\n", - "step 280/586 - loss: 0.6908 - acc: 0.5616 - 8ms/step\n", - "step 290/586 - loss: 0.6217 - acc: 0.5685 - 8ms/step\n", - "step 300/586 - loss: 0.5417 - acc: 0.5743 - 8ms/step\n", - "step 310/586 - loss: 0.5207 - acc: 0.5780 - 8ms/step\n", - "step 320/586 - loss: 0.5410 - acc: 0.5841 - 8ms/step\n", - "step 330/586 - loss: 0.5647 - acc: 0.5883 - 8ms/step\n", - "step 340/586 - loss: 0.4975 - acc: 0.5930 - 8ms/step\n", - "step 350/586 - loss: 0.5611 - acc: 0.5988 - 8ms/step\n", - "step 360/586 - loss: 0.5176 - acc: 0.6044 - 8ms/step\n", - "step 370/586 - loss: 0.4878 - acc: 0.6087 - 8ms/step\n", - "step 380/586 - loss: 0.5079 - acc: 0.6131 - 8ms/step\n", - "step 390/586 - loss: 0.4918 - acc: 0.6178 - 8ms/step\n", - "step 400/586 - loss: 0.4999 - acc: 0.6220 - 8ms/step\n", - "step 410/586 - loss: 0.5087 - acc: 0.6254 - 8ms/step\n", - "step 420/586 - loss: 0.4500 - acc: 0.6286 - 8ms/step\n", - "step 430/586 - loss: 0.4677 - acc: 0.6338 - 8ms/step\n", - "step 440/586 - loss: 0.4354 - acc: 0.6377 - 8ms/step\n", - "step 450/586 - loss: 0.4049 - acc: 0.6424 - 8ms/step\n", - "step 460/586 - loss: 0.4874 - acc: 0.6459 - 8ms/step\n", - "step 470/586 - loss: 0.6287 - acc: 0.6497 - 8ms/step\n", - "step 480/586 - loss: 0.4633 - acc: 0.6535 - 8ms/step\n", - "step 490/586 - loss: 0.4972 - acc: 0.6573 - 8ms/step\n", - "step 500/586 - loss: 0.5369 - acc: 0.6603 - 8ms/step\n", - "step 510/586 - loss: 0.5170 - acc: 0.6634 - 8ms/step\n", - "step 520/586 - loss: 0.4569 - acc: 0.6665 - 8ms/step\n", - "step 530/586 - loss: 0.4837 - acc: 0.6696 - 8ms/step\n", - "step 540/586 - loss: 0.4510 - acc: 0.6726 - 8ms/step\n", - "step 550/586 - loss: 0.5162 - acc: 0.6756 - 8ms/step\n", - "step 560/586 - loss: 0.4821 - acc: 0.6781 - 8ms/step\n", - "step 570/586 - loss: 0.4589 - acc: 0.6806 - 8ms/step\n", - "step 580/586 - loss: 0.4688 - acc: 0.6830 - 8ms/step\n", - "step 586/586 - loss: 0.4162 
- acc: 0.6847 - 8ms/step\n", + "step 586/586 [==============================] - loss: 0.3736 - acc: 0.9740 - 6ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4399 - acc: 0.8313 - 3ms/step\n", - "step 20/196 - loss: 0.4896 - acc: 0.8266 - 2ms/step\n", - "step 30/196 - loss: 0.6432 - acc: 0.8187 - 2ms/step\n", - "step 40/196 - loss: 0.4953 - acc: 0.8156 - 2ms/step\n", - "step 50/196 - loss: 0.4499 - acc: 0.8081 - 2ms/step\n", - "step 60/196 - loss: 0.4401 - acc: 0.8130 - 2ms/step\n", - "step 70/196 - loss: 0.4320 - acc: 0.8121 - 2ms/step\n", - "step 80/196 - loss: 0.5158 - acc: 0.8102 - 2ms/step\n", - "step 90/196 - loss: 0.6223 - acc: 0.8115 - 2ms/step\n", - "step 100/196 - loss: 0.4908 - acc: 0.8172 - 2ms/step\n", - "step 110/196 - loss: 0.4968 - acc: 0.8173 - 2ms/step\n", - "step 120/196 - loss: 0.4446 - acc: 0.8161 - 2ms/step\n", - "step 130/196 - loss: 0.4763 - acc: 0.8159 - 2ms/step\n", - "step 140/196 - loss: 0.4702 - acc: 0.8174 - 2ms/step\n", - "step 150/196 - loss: 0.5083 - acc: 0.8163 - 2ms/step\n", - "step 160/196 - loss: 0.5015 - acc: 0.8139 - 2ms/step\n", - "step 170/196 - loss: 0.5416 - acc: 0.8116 - 2ms/step\n", - "step 180/196 - loss: 0.4286 - acc: 0.8120 - 2ms/step\n", - "step 190/196 - loss: 0.5156 - acc: 0.8123 - 2ms/step\n", - "step 196/196 - loss: 0.5552 - acc: 0.8122 - 2ms/step\n", + "step 196/196 [==============================] - loss: 0.5626 - acc: 0.8726 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 2/10\n", - "step 10/586 - loss: 0.4843 - acc: 0.8375 - 7ms/step\n", - "step 20/586 - loss: 0.4507 - acc: 0.8516 - 7ms/step\n", - "step 30/586 - loss: 0.5005 - acc: 0.8521 - 7ms/step\n", - "step 40/586 - loss: 0.4608 - acc: 0.8531 - 7ms/step\n", - "step 50/586 - loss: 0.4466 - acc: 0.8481 - 7ms/step\n", - "step 60/586 - loss: 0.5826 - acc: 0.8406 - 7ms/step\n", - "step 70/586 - loss: 0.4946 - acc: 0.8415 - 7ms/step\n", - "step 80/586 - loss: 0.4346 - acc: 0.8410 - 7ms/step\n", - "step 90/586 - loss: 0.4112 - acc: 0.8465 - 7ms/step\n", - "step 100/586 - loss: 0.4780 - acc: 0.8472 - 7ms/step\n", - "step 110/586 - loss: 0.4085 - acc: 0.8477 - 7ms/step\n", - "step 120/586 - loss: 0.4291 - acc: 0.8490 - 7ms/step\n", - "step 130/586 - loss: 0.4203 - acc: 0.8498 - 7ms/step\n", - "step 140/586 - loss: 0.4696 - acc: 0.8496 - 7ms/step\n", - "step 150/586 - loss: 0.4195 - acc: 0.8502 - 7ms/step\n", - "step 160/586 - loss: 0.4378 - acc: 0.8520 - 7ms/step\n", - "step 170/586 - loss: 0.4465 - acc: 0.8528 - 7ms/step\n", - "step 180/586 - loss: 0.4533 - acc: 0.8535 - 7ms/step\n", - "step 190/586 - loss: 0.4143 - acc: 0.8556 - 7ms/step\n", - "step 200/586 - loss: 0.4385 - acc: 0.8567 - 7ms/step\n", - "step 210/586 - loss: 0.4712 - acc: 0.8580 - 7ms/step\n", - "step 220/586 - loss: 0.4541 - acc: 0.8587 - 7ms/step\n", - "step 230/586 - loss: 0.5102 - acc: 0.8598 - 7ms/step\n", - "step 240/586 - loss: 0.4461 - acc: 0.8604 - 7ms/step\n", - "step 250/586 - loss: 0.4888 - acc: 0.8598 - 7ms/step\n", - "step 260/586 - loss: 0.4808 - acc: 0.8594 - 7ms/step\n", - "step 270/586 - loss: 0.3762 - acc: 0.8600 - 7ms/step\n", - "step 280/586 - loss: 0.4755 - acc: 0.8609 - 7ms/step\n", - "step 290/586 - loss: 0.4851 - acc: 0.8610 - 7ms/step\n", - "step 300/586 - loss: 0.4570 - acc: 0.8615 - 7ms/step\n", - "step 310/586 - loss: 0.4403 - acc: 0.8611 - 7ms/step\n", - "step 320/586 - loss: 0.3967 - acc: 0.8611 - 7ms/step\n", - "step 330/586 - loss: 0.5665 - acc: 0.8614 - 7ms/step\n", - "step 340/586 - loss: 0.4581 - acc: 0.8616 - 7ms/step\n", - "step 350/586 - loss: 0.4790 - acc: 0.8614 - 
7ms/step\n", - "step 360/586 - loss: 0.4301 - acc: 0.8619 - 7ms/step\n", - "step 370/586 - loss: 0.4055 - acc: 0.8617 - 7ms/step\n", - "step 380/586 - loss: 0.3873 - acc: 0.8626 - 7ms/step\n", - "step 390/586 - loss: 0.3884 - acc: 0.8635 - 7ms/step\n", - "step 400/586 - loss: 0.3815 - acc: 0.8634 - 7ms/step\n", - "step 410/586 - loss: 0.4561 - acc: 0.8633 - 7ms/step\n", - "step 420/586 - loss: 0.4677 - acc: 0.8631 - 7ms/step\n", - "step 430/586 - loss: 0.4463 - acc: 0.8624 - 7ms/step\n", - "step 440/586 - loss: 0.4642 - acc: 0.8624 - 7ms/step\n", - "step 450/586 - loss: 0.4780 - acc: 0.8626 - 7ms/step\n", - "step 460/586 - loss: 0.4521 - acc: 0.8627 - 7ms/step\n", - "step 470/586 - loss: 0.4318 - acc: 0.8628 - 7ms/step\n", - "step 480/586 - loss: 0.4390 - acc: 0.8628 - 7ms/step\n", - "step 490/586 - loss: 0.4787 - acc: 0.8629 - 7ms/step\n", - "step 500/586 - loss: 0.4620 - acc: 0.8631 - 7ms/step\n", - "step 510/586 - loss: 0.5165 - acc: 0.8631 - 7ms/step\n", - "step 520/586 - loss: 0.4316 - acc: 0.8623 - 7ms/step\n", - "step 530/586 - loss: 0.3964 - acc: 0.8627 - 7ms/step\n", - "step 540/586 - loss: 0.4333 - acc: 0.8631 - 7ms/step\n", - "step 550/586 - loss: 0.3577 - acc: 0.8629 - 7ms/step\n", - "step 560/586 - loss: 0.4475 - acc: 0.8631 - 7ms/step\n", - "step 570/586 - loss: 0.3820 - acc: 0.8634 - 7ms/step\n", - "step 580/586 - loss: 0.4899 - acc: 0.8636 - 7ms/step\n", - "step 586/586 - loss: 0.3425 - acc: 0.8641 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3499 - acc: 0.9748 - 6ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4062 - acc: 0.8781 - 3ms/step\n", - "step 20/196 - loss: 0.4372 - acc: 0.8781 - 3ms/step\n", - "step 30/196 - loss: 0.5886 - acc: 0.8750 - 3ms/step\n", - "step 40/196 - loss: 0.4661 - acc: 0.8648 - 3ms/step\n", - "step 50/196 - loss: 0.4340 - acc: 0.8612 - 3ms/step\n", - "step 60/196 - loss: 0.4301 - acc: 0.8604 - 3ms/step\n", - "step 70/196 - loss: 0.4055 - acc: 0.8616 - 3ms/step\n", - "step 80/196 - loss: 0.4645 - acc: 0.8590 - 3ms/step\n", - "step 90/196 - loss: 0.5809 - acc: 0.8597 - 3ms/step\n", - "step 100/196 - loss: 0.4399 - acc: 0.8606 - 3ms/step\n", - "step 110/196 - loss: 0.4577 - acc: 0.8608 - 3ms/step\n", - "step 120/196 - loss: 0.3500 - acc: 0.8581 - 3ms/step\n", - "step 130/196 - loss: 0.4330 - acc: 0.8587 - 3ms/step\n", - "step 140/196 - loss: 0.4096 - acc: 0.8603 - 3ms/step\n", - "step 150/196 - loss: 0.4189 - acc: 0.8602 - 3ms/step\n", - "step 160/196 - loss: 0.4849 - acc: 0.8588 - 3ms/step\n", - "step 170/196 - loss: 0.4570 - acc: 0.8590 - 3ms/step\n", - "step 180/196 - loss: 0.3667 - acc: 0.8601 - 3ms/step\n", - "step 190/196 - loss: 0.4623 - acc: 0.8604 - 3ms/step\n", - "step 196/196 - loss: 0.5284 - acc: 0.8619 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.7976 - acc: 0.8651 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 3/10\n", - "step 10/586 - loss: 0.4269 - acc: 0.8875 - 7ms/step\n", - "step 20/586 - loss: 0.3295 - acc: 0.9031 - 7ms/step\n", - "step 30/586 - loss: 0.4543 - acc: 0.9062 - 7ms/step\n", - "step 40/586 - loss: 0.3627 - acc: 0.9102 - 7ms/step\n", - "step 50/586 - loss: 0.4724 - acc: 0.9087 - 7ms/step\n", - "step 60/586 - loss: 0.4065 - acc: 0.9104 - 7ms/step\n", - "step 70/586 - loss: 0.3910 - acc: 0.9134 - 7ms/step\n", - "step 80/586 - loss: 0.4536 - acc: 0.9086 - 7ms/step\n", - "step 90/586 - loss: 0.4164 - acc: 0.9052 - 7ms/step\n", - "step 100/586 - loss: 0.5490 - acc: 0.8994 - 7ms/step\n", - "step 110/586 - loss: 0.4750 - acc: 0.8952 - 7ms/step\n", - "step 
120/586 - loss: 0.3541 - acc: 0.8964 - 7ms/step\n", - "step 130/586 - loss: 0.3955 - acc: 0.8974 - 7ms/step\n", - "step 140/586 - loss: 0.4073 - acc: 0.8971 - 7ms/step\n", - "step 150/586 - loss: 0.4303 - acc: 0.8985 - 7ms/step\n", - "step 160/586 - loss: 0.4012 - acc: 0.8984 - 7ms/step\n", - "step 170/586 - loss: 0.4510 - acc: 0.8987 - 7ms/step\n", - "step 180/586 - loss: 0.4806 - acc: 0.8993 - 7ms/step\n", - "step 190/586 - loss: 0.4275 - acc: 0.8998 - 7ms/step\n", - "step 200/586 - loss: 0.4005 - acc: 0.8995 - 7ms/step\n", - "step 210/586 - loss: 0.4164 - acc: 0.8994 - 7ms/step\n", - "step 220/586 - loss: 0.4389 - acc: 0.8999 - 7ms/step\n", - "step 230/586 - loss: 0.4320 - acc: 0.9003 - 7ms/step\n", - "step 240/586 - loss: 0.4554 - acc: 0.8995 - 7ms/step\n", - "step 250/586 - loss: 0.4506 - acc: 0.8986 - 7ms/step\n", - "step 260/586 - loss: 0.3554 - acc: 0.8987 - 7ms/step\n", - "step 270/586 - loss: 0.4138 - acc: 0.8992 - 7ms/step\n", - "step 280/586 - loss: 0.3524 - acc: 0.8987 - 7ms/step\n", - "step 290/586 - loss: 0.3577 - acc: 0.8995 - 7ms/step\n", - "step 300/586 - loss: 0.3739 - acc: 0.8996 - 7ms/step\n", - "step 310/586 - loss: 0.3896 - acc: 0.8996 - 7ms/step\n", - "step 320/586 - loss: 0.3983 - acc: 0.9000 - 7ms/step\n", - "step 330/586 - loss: 0.4169 - acc: 0.9001 - 7ms/step\n", - "step 340/586 - loss: 0.4219 - acc: 0.8982 - 7ms/step\n", - "step 350/586 - loss: 0.5360 - acc: 0.8988 - 7ms/step\n", - "step 360/586 - loss: 0.3557 - acc: 0.8984 - 7ms/step\n", - "step 370/586 - loss: 0.4556 - acc: 0.8978 - 7ms/step\n", - "step 380/586 - loss: 0.3822 - acc: 0.8975 - 7ms/step\n", - "step 390/586 - loss: 0.4795 - acc: 0.8967 - 7ms/step\n", - "step 400/586 - loss: 0.4399 - acc: 0.8965 - 7ms/step\n", - "step 410/586 - loss: 0.4165 - acc: 0.8963 - 7ms/step\n", - "step 420/586 - loss: 0.4211 - acc: 0.8968 - 7ms/step\n", - "step 430/586 - loss: 0.3752 - acc: 0.8971 - 7ms/step\n", - "step 440/586 - loss: 0.4722 - acc: 0.8962 - 7ms/step\n", - "step 450/586 - loss: 0.3402 - acc: 0.8963 - 7ms/step\n", - "step 460/586 - loss: 0.4418 - acc: 0.8967 - 7ms/step\n", - "step 470/586 - loss: 0.3263 - acc: 0.8975 - 7ms/step\n", - "step 480/586 - loss: 0.3991 - acc: 0.8974 - 7ms/step\n", - "step 490/586 - loss: 0.3989 - acc: 0.8979 - 7ms/step\n", - "step 500/586 - loss: 0.4587 - acc: 0.8978 - 7ms/step\n", - "step 510/586 - loss: 0.3556 - acc: 0.8975 - 7ms/step\n", - "step 520/586 - loss: 0.4912 - acc: 0.8977 - 7ms/step\n", - "step 530/586 - loss: 0.4094 - acc: 0.8979 - 7ms/step\n", - "step 540/586 - loss: 0.3773 - acc: 0.8984 - 7ms/step\n", - "step 550/586 - loss: 0.4833 - acc: 0.8980 - 7ms/step\n", - "step 560/586 - loss: 0.3811 - acc: 0.8980 - 7ms/step\n", - "step 570/586 - loss: 0.4198 - acc: 0.8978 - 7ms/step\n", - "step 580/586 - loss: 0.3985 - acc: 0.8984 - 7ms/step\n", - "step 586/586 - loss: 0.4302 - acc: 0.8987 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3137 - acc: 0.9756 - 6ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4235 - acc: 0.8531 - 3ms/step\n", - "step 20/196 - loss: 0.4380 - acc: 0.8562 - 3ms/step\n", - "step 30/196 - loss: 0.5421 - acc: 0.8583 - 3ms/step\n", - "step 40/196 - loss: 0.4682 - acc: 0.8562 - 3ms/step\n", - "step 50/196 - loss: 0.4120 - acc: 0.8588 - 3ms/step\n", - "step 60/196 - loss: 0.3863 - acc: 0.8589 - 3ms/step\n", - "step 70/196 - loss: 0.4057 - acc: 0.8634 - 3ms/step\n", - "step 80/196 - loss: 0.4562 - acc: 0.8633 - 3ms/step\n", - "step 90/196 - loss: 0.5596 - acc: 0.8632 - 3ms/step\n", - "step 100/196 - loss: 0.4493 - acc: 
0.8653 - 3ms/step\n", - "step 110/196 - loss: 0.4656 - acc: 0.8639 - 3ms/step\n", - "step 120/196 - loss: 0.3922 - acc: 0.8604 - 3ms/step\n", - "step 130/196 - loss: 0.4482 - acc: 0.8608 - 3ms/step\n", - "step 140/196 - loss: 0.3829 - acc: 0.8632 - 3ms/step\n", - "step 150/196 - loss: 0.4171 - acc: 0.8638 - 3ms/step\n", - "step 160/196 - loss: 0.4876 - acc: 0.8615 - 3ms/step\n", - "step 170/196 - loss: 0.4649 - acc: 0.8608 - 3ms/step\n", - "step 180/196 - loss: 0.3737 - acc: 0.8627 - 3ms/step\n", - "step 190/196 - loss: 0.4659 - acc: 0.8620 - 3ms/step\n", - "step 196/196 - loss: 0.4331 - acc: 0.8634 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.6264 - acc: 0.8701 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 4/10\n", - "step 10/586 - loss: 0.4649 - acc: 0.8938 - 7ms/step\n", - "step 20/586 - loss: 0.4502 - acc: 0.8891 - 7ms/step\n", - "step 30/586 - loss: 0.3967 - acc: 0.8969 - 7ms/step\n", - "step 40/586 - loss: 0.3733 - acc: 0.9000 - 7ms/step\n", - "step 50/586 - loss: 0.4118 - acc: 0.9094 - 7ms/step\n", - "step 60/586 - loss: 0.3935 - acc: 0.9094 - 7ms/step\n", - "step 70/586 - loss: 0.3910 - acc: 0.9125 - 7ms/step\n", - "step 80/586 - loss: 0.3524 - acc: 0.9168 - 7ms/step\n", - "step 90/586 - loss: 0.3936 - acc: 0.9184 - 7ms/step\n", - "step 100/586 - loss: 0.3414 - acc: 0.9219 - 7ms/step\n", - "step 110/586 - loss: 0.3739 - acc: 0.9244 - 7ms/step\n", - "step 120/586 - loss: 0.4057 - acc: 0.9237 - 7ms/step\n", - "step 130/586 - loss: 0.3796 - acc: 0.9226 - 7ms/step\n", - "step 140/586 - loss: 0.3649 - acc: 0.9219 - 7ms/step\n", - "step 150/586 - loss: 0.3848 - acc: 0.9208 - 7ms/step\n", - "step 160/586 - loss: 0.4138 - acc: 0.9207 - 7ms/step\n", - "step 170/586 - loss: 0.3893 - acc: 0.9219 - 7ms/step\n", - "step 180/586 - loss: 0.3575 - acc: 0.9229 - 7ms/step\n", - "step 190/586 - loss: 0.3528 - acc: 0.9248 - 7ms/step\n", - "step 200/586 - loss: 0.4436 - acc: 0.9231 - 7ms/step\n", - "step 210/586 - loss: 0.3936 - acc: 0.9232 - 7ms/step\n", - "step 220/586 - loss: 0.3917 - acc: 0.9213 - 7ms/step\n", - "step 230/586 - loss: 0.3866 - acc: 0.9219 - 7ms/step\n", - "step 240/586 - loss: 0.4124 - acc: 0.9224 - 7ms/step\n", - "step 250/586 - loss: 0.4374 - acc: 0.9215 - 7ms/step\n", - "step 260/586 - loss: 0.3602 - acc: 0.9218 - 7ms/step\n", - "step 270/586 - loss: 0.3354 - acc: 0.9223 - 7ms/step\n", - "step 280/586 - loss: 0.4723 - acc: 0.9220 - 7ms/step\n", - "step 290/586 - loss: 0.3258 - acc: 0.9230 - 7ms/step\n", - "step 300/586 - loss: 0.3674 - acc: 0.9236 - 7ms/step\n", - "step 310/586 - loss: 0.3226 - acc: 0.9241 - 6ms/step\n", - "step 320/586 - loss: 0.3961 - acc: 0.9241 - 6ms/step\n", - "step 330/586 - loss: 0.4282 - acc: 0.9237 - 6ms/step\n", - "step 340/586 - loss: 0.3943 - acc: 0.9235 - 6ms/step\n", - "step 350/586 - loss: 0.4288 - acc: 0.9224 - 6ms/step\n", - "step 360/586 - loss: 0.4189 - acc: 0.9221 - 7ms/step\n", - "step 370/586 - loss: 0.4015 - acc: 0.9227 - 7ms/step\n", - "step 380/586 - loss: 0.3946 - acc: 0.9230 - 7ms/step\n", - "step 390/586 - loss: 0.3763 - acc: 0.9233 - 7ms/step\n", - "step 400/586 - loss: 0.3684 - acc: 0.9232 - 7ms/step\n", - "step 410/586 - loss: 0.3471 - acc: 0.9233 - 7ms/step\n", - "step 420/586 - loss: 0.4221 - acc: 0.9234 - 7ms/step\n", - "step 430/586 - loss: 0.4527 - acc: 0.9232 - 7ms/step\n", - "step 440/586 - loss: 0.3835 - acc: 0.9233 - 7ms/step\n", - "step 450/586 - loss: 0.4414 - acc: 0.9233 - 7ms/step\n", - "step 460/586 - loss: 0.3542 - acc: 0.9235 - 7ms/step\n", - "step 470/586 - loss: 0.3878 - acc: 0.9236 - 
7ms/step\n", - "step 480/586 - loss: 0.4531 - acc: 0.9235 - 7ms/step\n", - "step 490/586 - loss: 0.4480 - acc: 0.9234 - 7ms/step\n", - "step 500/586 - loss: 0.3302 - acc: 0.9239 - 7ms/step\n", - "step 510/586 - loss: 0.3513 - acc: 0.9238 - 7ms/step\n", - "step 520/586 - loss: 0.4588 - acc: 0.9237 - 7ms/step\n", - "step 530/586 - loss: 0.3953 - acc: 0.9238 - 7ms/step\n", - "step 540/586 - loss: 0.4340 - acc: 0.9242 - 7ms/step\n", - "step 550/586 - loss: 0.3836 - acc: 0.9243 - 7ms/step\n", - "step 560/586 - loss: 0.3799 - acc: 0.9241 - 7ms/step\n", - "step 570/586 - loss: 0.4244 - acc: 0.9240 - 7ms/step\n", - "step 580/586 - loss: 0.3150 - acc: 0.9236 - 7ms/step\n", - "step 586/586 - loss: 0.5743 - acc: 0.9230 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3470 - acc: 0.9772 - 6ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.3942 - acc: 0.8906 - 2ms/step\n", - "step 20/196 - loss: 0.4010 - acc: 0.8891 - 2ms/step\n", - "step 30/196 - loss: 0.5784 - acc: 0.8750 - 2ms/step\n", - "step 40/196 - loss: 0.4673 - acc: 0.8703 - 2ms/step\n", - "step 50/196 - loss: 0.4671 - acc: 0.8669 - 2ms/step\n", - "step 60/196 - loss: 0.4023 - acc: 0.8656 - 2ms/step\n", - "step 70/196 - loss: 0.4319 - acc: 0.8679 - 2ms/step\n", - "step 80/196 - loss: 0.4205 - acc: 0.8664 - 2ms/step\n", - "step 90/196 - loss: 0.5517 - acc: 0.8656 - 2ms/step\n", - "step 100/196 - loss: 0.4190 - acc: 0.8675 - 2ms/step\n", - "step 110/196 - loss: 0.4450 - acc: 0.8682 - 2ms/step\n", - "step 120/196 - loss: 0.3771 - acc: 0.8651 - 2ms/step\n", - "step 130/196 - loss: 0.4033 - acc: 0.8659 - 2ms/step\n", - "step 140/196 - loss: 0.4189 - acc: 0.8667 - 2ms/step\n", - "step 150/196 - loss: 0.4362 - acc: 0.8660 - 2ms/step\n", - "step 160/196 - loss: 0.5045 - acc: 0.8643 - 2ms/step\n", - "step 170/196 - loss: 0.3803 - acc: 0.8651 - 2ms/step\n", - "step 180/196 - loss: 0.3570 - acc: 0.8672 - 2ms/step\n", - "step 190/196 - loss: 0.4183 - acc: 0.8679 - 2ms/step\n", - "step 196/196 - loss: 0.5245 - acc: 0.8683 - 2ms/step\n", + "step 196/196 [==============================] - loss: 0.6550 - acc: 0.8714 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 5/10\n", - "step 10/586 - loss: 0.3663 - acc: 0.9437 - 7ms/step\n", - "step 20/586 - loss: 0.3953 - acc: 0.9531 - 7ms/step\n", - "step 30/586 - loss: 0.4353 - acc: 0.9448 - 7ms/step\n", - "step 40/586 - loss: 0.4004 - acc: 0.9445 - 7ms/step\n", - "step 50/586 - loss: 0.3962 - acc: 0.9437 - 7ms/step\n", - "step 60/586 - loss: 0.3936 - acc: 0.9453 - 7ms/step\n", - "step 70/586 - loss: 0.3608 - acc: 0.9455 - 6ms/step\n", - "step 80/586 - loss: 0.3816 - acc: 0.9441 - 6ms/step\n", - "step 90/586 - loss: 0.4682 - acc: 0.9437 - 6ms/step\n", - "step 100/586 - loss: 0.3616 - acc: 0.9428 - 6ms/step\n", - "step 110/586 - loss: 0.4110 - acc: 0.9432 - 6ms/step\n", - "step 120/586 - loss: 0.3548 - acc: 0.9437 - 6ms/step\n", - "step 130/586 - loss: 0.3788 - acc: 0.9433 - 6ms/step\n", - "step 140/586 - loss: 0.3626 - acc: 0.9433 - 6ms/step\n", - "step 150/586 - loss: 0.3856 - acc: 0.9435 - 6ms/step\n", - "step 160/586 - loss: 0.4348 - acc: 0.9437 - 6ms/step\n", - "step 170/586 - loss: 0.3337 - acc: 0.9443 - 6ms/step\n", - "step 180/586 - loss: 0.3341 - acc: 0.9439 - 6ms/step\n", - "step 190/586 - loss: 0.3483 - acc: 0.9434 - 6ms/step\n", - "step 200/586 - loss: 0.3253 - acc: 0.9431 - 6ms/step\n", - "step 210/586 - loss: 0.3671 - acc: 0.9418 - 6ms/step\n", - "step 220/586 - loss: 0.3685 - acc: 0.9415 - 6ms/step\n", - "step 230/586 - loss: 0.4182 - acc: 0.9413 - 6ms/step\n", - "step 
240/586 - loss: 0.3367 - acc: 0.9410 - 6ms/step\n", - "step 250/586 - loss: 0.4380 - acc: 0.9407 - 6ms/step\n", - "step 260/586 - loss: 0.3579 - acc: 0.9394 - 6ms/step\n", - "step 270/586 - loss: 0.3499 - acc: 0.9388 - 6ms/step\n", - "step 280/586 - loss: 0.4419 - acc: 0.9384 - 6ms/step\n", - "step 290/586 - loss: 0.4185 - acc: 0.9378 - 6ms/step\n", - "step 300/586 - loss: 0.4595 - acc: 0.9375 - 6ms/step\n", - "step 310/586 - loss: 0.3226 - acc: 0.9378 - 6ms/step\n", - "step 320/586 - loss: 0.3661 - acc: 0.9382 - 6ms/step\n", - "step 330/586 - loss: 0.3806 - acc: 0.9383 - 6ms/step\n", - "step 340/586 - loss: 0.4106 - acc: 0.9380 - 6ms/step\n", - "step 350/586 - loss: 0.4062 - acc: 0.9375 - 6ms/step\n", - "step 360/586 - loss: 0.3989 - acc: 0.9375 - 6ms/step\n", - "step 370/586 - loss: 0.3514 - acc: 0.9383 - 6ms/step\n", - "step 380/586 - loss: 0.3183 - acc: 0.9391 - 6ms/step\n", - "step 390/586 - loss: 0.3472 - acc: 0.9395 - 6ms/step\n", - "step 400/586 - loss: 0.3165 - acc: 0.9393 - 6ms/step\n", - "step 410/586 - loss: 0.3192 - acc: 0.9393 - 6ms/step\n", - "step 420/586 - loss: 0.3826 - acc: 0.9394 - 7ms/step\n", - "step 430/586 - loss: 0.3252 - acc: 0.9401 - 7ms/step\n", - "step 440/586 - loss: 0.3815 - acc: 0.9406 - 7ms/step\n", - "step 450/586 - loss: 0.3926 - acc: 0.9408 - 7ms/step\n", - "step 460/586 - loss: 0.4072 - acc: 0.9411 - 7ms/step\n", - "step 470/586 - loss: 0.4134 - acc: 0.9412 - 7ms/step\n", - "step 480/586 - loss: 0.3375 - acc: 0.9413 - 7ms/step\n", - "step 490/586 - loss: 0.3880 - acc: 0.9414 - 7ms/step\n", - "step 500/586 - loss: 0.3885 - acc: 0.9417 - 7ms/step\n", - "step 510/586 - loss: 0.3638 - acc: 0.9417 - 7ms/step\n", - "step 520/586 - loss: 0.4671 - acc: 0.9414 - 7ms/step\n", - "step 530/586 - loss: 0.3618 - acc: 0.9412 - 7ms/step\n", - "step 540/586 - loss: 0.3202 - acc: 0.9409 - 7ms/step\n", - "step 550/586 - loss: 0.3325 - acc: 0.9405 - 7ms/step\n", - "step 560/586 - loss: 0.3969 - acc: 0.9403 - 7ms/step\n", - "step 570/586 - loss: 0.3870 - acc: 0.9399 - 7ms/step\n", - "step 580/586 - loss: 0.3297 - acc: 0.9402 - 7ms/step\n", - "step 586/586 - loss: 0.3533 - acc: 0.9400 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3507 - acc: 0.9776 - 7ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.3991 - acc: 0.8812 - 3ms/step\n", - "step 20/196 - loss: 0.4031 - acc: 0.8875 - 2ms/step\n", - "step 30/196 - loss: 0.5758 - acc: 0.8760 - 2ms/step\n", - "step 40/196 - loss: 0.4588 - acc: 0.8695 - 3ms/step\n", - "step 50/196 - loss: 0.4694 - acc: 0.8669 - 3ms/step\n", - "step 60/196 - loss: 0.4034 - acc: 0.8661 - 3ms/step\n", - "step 70/196 - loss: 0.4236 - acc: 0.8714 - 3ms/step\n", - "step 80/196 - loss: 0.4264 - acc: 0.8703 - 3ms/step\n", - "step 90/196 - loss: 0.5121 - acc: 0.8698 - 3ms/step\n", - "step 100/196 - loss: 0.3963 - acc: 0.8709 - 3ms/step\n", - "step 110/196 - loss: 0.4396 - acc: 0.8716 - 3ms/step\n", - "step 120/196 - loss: 0.3787 - acc: 0.8680 - 3ms/step\n", - "step 130/196 - loss: 0.4081 - acc: 0.8678 - 3ms/step\n", - "step 140/196 - loss: 0.4171 - acc: 0.8676 - 3ms/step\n", - "step 150/196 - loss: 0.4276 - acc: 0.8675 - 3ms/step\n", - "step 160/196 - loss: 0.5145 - acc: 0.8660 - 3ms/step\n", - "step 170/196 - loss: 0.3994 - acc: 0.8664 - 3ms/step\n", - "step 180/196 - loss: 0.3495 - acc: 0.8686 - 3ms/step\n", - "step 190/196 - loss: 0.4370 - acc: 0.8696 - 3ms/step\n", - "step 196/196 - loss: 0.4342 - acc: 0.8706 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.7118 - acc: 0.8726 - 2ms/step \n", 
"Eval samples: 6250\n", "Epoch 6/10\n", - "step 10/586 - loss: 0.3305 - acc: 0.9656 - 7ms/step\n", - "step 20/586 - loss: 0.3285 - acc: 0.9641 - 7ms/step\n", - "step 30/586 - loss: 0.3835 - acc: 0.9563 - 8ms/step\n", - "step 40/586 - loss: 0.4051 - acc: 0.9492 - 7ms/step\n", - "step 50/586 - loss: 0.3310 - acc: 0.9506 - 7ms/step\n", - "step 60/586 - loss: 0.3157 - acc: 0.9542 - 7ms/step\n", - "step 70/586 - loss: 0.3776 - acc: 0.9540 - 7ms/step\n", - "step 80/586 - loss: 0.4235 - acc: 0.9531 - 8ms/step\n", - "step 90/586 - loss: 0.3765 - acc: 0.9538 - 8ms/step\n", - "step 100/586 - loss: 0.4109 - acc: 0.9537 - 8ms/step\n", - "step 110/586 - loss: 0.3178 - acc: 0.9548 - 8ms/step\n", - "step 120/586 - loss: 0.3332 - acc: 0.9560 - 8ms/step\n", - "step 130/586 - loss: 0.3541 - acc: 0.9560 - 8ms/step\n", - "step 140/586 - loss: 0.4426 - acc: 0.9551 - 8ms/step\n", - "step 150/586 - loss: 0.3988 - acc: 0.9550 - 8ms/step\n", - "step 160/586 - loss: 0.3752 - acc: 0.9553 - 8ms/step\n", - "step 170/586 - loss: 0.3670 - acc: 0.9548 - 8ms/step\n", - "step 180/586 - loss: 0.3524 - acc: 0.9542 - 8ms/step\n", - "step 190/586 - loss: 0.4168 - acc: 0.9531 - 8ms/step\n", - "step 200/586 - loss: 0.4119 - acc: 0.9536 - 8ms/step\n", - "step 210/586 - loss: 0.3779 - acc: 0.9533 - 8ms/step\n", - "step 220/586 - loss: 0.4391 - acc: 0.9536 - 8ms/step\n", - "step 230/586 - loss: 0.3181 - acc: 0.9537 - 8ms/step\n", - "step 240/586 - loss: 0.3546 - acc: 0.9543 - 8ms/step\n", - "step 250/586 - loss: 0.3768 - acc: 0.9545 - 8ms/step\n", - "step 260/586 - loss: 0.3607 - acc: 0.9544 - 7ms/step\n", - "step 270/586 - loss: 0.3783 - acc: 0.9546 - 7ms/step\n", - "step 280/586 - loss: 0.3453 - acc: 0.9542 - 7ms/step\n", - "step 290/586 - loss: 0.3470 - acc: 0.9552 - 7ms/step\n", - "step 300/586 - loss: 0.3719 - acc: 0.9547 - 7ms/step\n", - "step 310/586 - loss: 0.3817 - acc: 0.9542 - 7ms/step\n", - "step 320/586 - loss: 0.3873 - acc: 0.9546 - 7ms/step\n", - "step 330/586 - loss: 0.3214 - acc: 0.9545 - 7ms/step\n", - "step 340/586 - loss: 0.3188 - acc: 0.9546 - 7ms/step\n", - "step 350/586 - loss: 0.4134 - acc: 0.9546 - 7ms/step\n", - "step 360/586 - loss: 0.3154 - acc: 0.9549 - 7ms/step\n", - "step 370/586 - loss: 0.3639 - acc: 0.9550 - 7ms/step\n", - "step 380/586 - loss: 0.3960 - acc: 0.9550 - 7ms/step\n", - "step 390/586 - loss: 0.3466 - acc: 0.9551 - 7ms/step\n", - "step 400/586 - loss: 0.3370 - acc: 0.9555 - 7ms/step\n", - "step 410/586 - loss: 0.3841 - acc: 0.9555 - 7ms/step\n", - "step 420/586 - loss: 0.3942 - acc: 0.9552 - 7ms/step\n", - "step 430/586 - loss: 0.3547 - acc: 0.9551 - 7ms/step\n", - "step 440/586 - loss: 0.3170 - acc: 0.9553 - 7ms/step\n", - "step 450/586 - loss: 0.3266 - acc: 0.9556 - 7ms/step\n", - "step 460/586 - loss: 0.3429 - acc: 0.9553 - 7ms/step\n", - "step 470/586 - loss: 0.3164 - acc: 0.9555 - 7ms/step\n", - "step 480/586 - loss: 0.3724 - acc: 0.9555 - 7ms/step\n", - "step 490/586 - loss: 0.3533 - acc: 0.9554 - 7ms/step\n", - "step 500/586 - loss: 0.4149 - acc: 0.9556 - 7ms/step\n", - "step 510/586 - loss: 0.3577 - acc: 0.9552 - 7ms/step\n", - "step 520/586 - loss: 0.3712 - acc: 0.9553 - 7ms/step\n", - "step 530/586 - loss: 0.3233 - acc: 0.9555 - 7ms/step\n", - "step 540/586 - loss: 0.3177 - acc: 0.9556 - 7ms/step\n", - "step 550/586 - loss: 0.3508 - acc: 0.9557 - 7ms/step\n", - "step 560/586 - loss: 0.3778 - acc: 0.9553 - 7ms/step\n", - "step 570/586 - loss: 0.3157 - acc: 0.9552 - 7ms/step\n", - "step 580/586 - loss: 0.3832 - acc: 0.9551 - 7ms/step\n", - "step 586/586 - loss: 0.3516 - acc: 
0.9552 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3466 - acc: 0.9781 - 7ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.3740 - acc: 0.8875 - 3ms/step\n", - "step 20/196 - loss: 0.3935 - acc: 0.8922 - 3ms/step\n", - "step 30/196 - loss: 0.5860 - acc: 0.8771 - 3ms/step\n", - "step 40/196 - loss: 0.4778 - acc: 0.8719 - 3ms/step\n", - "step 50/196 - loss: 0.4675 - acc: 0.8669 - 3ms/step\n", - "step 60/196 - loss: 0.3974 - acc: 0.8625 - 3ms/step\n", - "step 70/196 - loss: 0.4264 - acc: 0.8670 - 3ms/step\n", - "step 80/196 - loss: 0.4237 - acc: 0.8668 - 3ms/step\n", - "step 90/196 - loss: 0.5286 - acc: 0.8660 - 3ms/step\n", - "step 100/196 - loss: 0.3980 - acc: 0.8669 - 3ms/step\n", - "step 110/196 - loss: 0.4362 - acc: 0.8676 - 3ms/step\n", - "step 120/196 - loss: 0.3779 - acc: 0.8638 - 3ms/step\n", - "step 130/196 - loss: 0.4090 - acc: 0.8644 - 3ms/step\n", - "step 140/196 - loss: 0.4323 - acc: 0.8652 - 3ms/step\n", - "step 150/196 - loss: 0.4067 - acc: 0.8654 - 3ms/step\n", - "step 160/196 - loss: 0.5107 - acc: 0.8637 - 3ms/step\n", - "step 170/196 - loss: 0.4058 - acc: 0.8649 - 3ms/step\n", - "step 180/196 - loss: 0.3519 - acc: 0.8663 - 3ms/step\n", - "step 190/196 - loss: 0.4454 - acc: 0.8663 - 3ms/step\n", - "step 196/196 - loss: 0.4457 - acc: 0.8672 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.7157 - acc: 0.8725 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 7/10\n", - "step 10/586 - loss: 0.4362 - acc: 0.9437 - 7ms/step\n", - "step 20/586 - loss: 0.3194 - acc: 0.9469 - 7ms/step\n", - "step 30/586 - loss: 0.4111 - acc: 0.9510 - 7ms/step\n", - "step 40/586 - loss: 0.3341 - acc: 0.9531 - 7ms/step\n", - "step 50/586 - loss: 0.3775 - acc: 0.9563 - 7ms/step\n", - "step 60/586 - loss: 0.3455 - acc: 0.9578 - 7ms/step\n", - "step 70/586 - loss: 0.3955 - acc: 0.9563 - 7ms/step\n", - "step 80/586 - loss: 0.3743 - acc: 0.9586 - 7ms/step\n", - "step 90/586 - loss: 0.3200 - acc: 0.9587 - 7ms/step\n", - "step 100/586 - loss: 0.3480 - acc: 0.9578 - 7ms/step\n", - "step 110/586 - loss: 0.3540 - acc: 0.9594 - 7ms/step\n", - "step 120/586 - loss: 0.3137 - acc: 0.9609 - 7ms/step\n", - "step 130/586 - loss: 0.3789 - acc: 0.9606 - 7ms/step\n", - "step 140/586 - loss: 0.3223 - acc: 0.9603 - 7ms/step\n", - "step 150/586 - loss: 0.3147 - acc: 0.9608 - 7ms/step\n", - "step 160/586 - loss: 0.3199 - acc: 0.9617 - 7ms/step\n", - "step 170/586 - loss: 0.3418 - acc: 0.9625 - 7ms/step\n", - "step 180/586 - loss: 0.3225 - acc: 0.9634 - 8ms/step\n", - "step 190/586 - loss: 0.3235 - acc: 0.9645 - 8ms/step\n", - "step 200/586 - loss: 0.3151 - acc: 0.9655 - 8ms/step\n", - "step 210/586 - loss: 0.3149 - acc: 0.9658 - 8ms/step\n", - "step 220/586 - loss: 0.3457 - acc: 0.9659 - 8ms/step\n", - "step 230/586 - loss: 0.3459 - acc: 0.9662 - 8ms/step\n", - "step 240/586 - loss: 0.3166 - acc: 0.9664 - 8ms/step\n", - "step 250/586 - loss: 0.3819 - acc: 0.9661 - 8ms/step\n", - "step 260/586 - loss: 0.3473 - acc: 0.9660 - 8ms/step\n", - "step 270/586 - loss: 0.3214 - acc: 0.9661 - 8ms/step\n", - "step 280/586 - loss: 0.4032 - acc: 0.9660 - 8ms/step\n", - "step 290/586 - loss: 0.3486 - acc: 0.9659 - 8ms/step\n", - "step 300/586 - loss: 0.3309 - acc: 0.9663 - 8ms/step\n", - "step 310/586 - loss: 0.3581 - acc: 0.9664 - 7ms/step\n", - "step 320/586 - loss: 0.4081 - acc: 0.9657 - 7ms/step\n", - "step 330/586 - loss: 0.3550 - acc: 0.9653 - 7ms/step\n", - "step 340/586 - loss: 0.3379 - acc: 0.9657 - 7ms/step\n", - "step 350/586 - loss: 0.3423 - acc: 0.9652 - 7ms/step\n", - 
"step 360/586 - loss: 0.3774 - acc: 0.9649 - 7ms/step\n", - "step 370/586 - loss: 0.3143 - acc: 0.9651 - 7ms/step\n", - "step 380/586 - loss: 0.3399 - acc: 0.9651 - 7ms/step\n", - "step 390/586 - loss: 0.3416 - acc: 0.9655 - 7ms/step\n", - "step 400/586 - loss: 0.3877 - acc: 0.9652 - 7ms/step\n", - "step 410/586 - loss: 0.4009 - acc: 0.9649 - 7ms/step\n", - "step 420/586 - loss: 0.3149 - acc: 0.9647 - 7ms/step\n", - "step 430/586 - loss: 0.3817 - acc: 0.9646 - 7ms/step\n", - "step 440/586 - loss: 0.3468 - acc: 0.9649 - 7ms/step\n", - "step 450/586 - loss: 0.3474 - acc: 0.9650 - 7ms/step\n", - "step 460/586 - loss: 0.3547 - acc: 0.9649 - 7ms/step\n", - "step 470/586 - loss: 0.3495 - acc: 0.9651 - 7ms/step\n", - "step 480/586 - loss: 0.3674 - acc: 0.9647 - 7ms/step\n", - "step 490/586 - loss: 0.3634 - acc: 0.9647 - 7ms/step\n", - "step 500/586 - loss: 0.3542 - acc: 0.9647 - 7ms/step\n", - "step 510/586 - loss: 0.3150 - acc: 0.9650 - 7ms/step\n", - "step 520/586 - loss: 0.3141 - acc: 0.9652 - 7ms/step\n", - "step 530/586 - loss: 0.3235 - acc: 0.9652 - 7ms/step\n", - "step 540/586 - loss: 0.3867 - acc: 0.9653 - 7ms/step\n", - "step 550/586 - loss: 0.3493 - acc: 0.9655 - 7ms/step\n", - "step 560/586 - loss: 0.4191 - acc: 0.9656 - 7ms/step\n", - "step 570/586 - loss: 0.3169 - acc: 0.9650 - 7ms/step\n", - "step 580/586 - loss: 0.3171 - acc: 0.9649 - 7ms/step\n", - "step 586/586 - loss: 0.3298 - acc: 0.9648 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3139 - acc: 0.9781 - 6ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4102 - acc: 0.8781 - 3ms/step\n", - "step 20/196 - loss: 0.3831 - acc: 0.8906 - 3ms/step\n", - "step 30/196 - loss: 0.5540 - acc: 0.8802 - 3ms/step\n", - "step 40/196 - loss: 0.5060 - acc: 0.8727 - 2ms/step\n", - "step 50/196 - loss: 0.4351 - acc: 0.8750 - 2ms/step\n", - "step 60/196 - loss: 0.3830 - acc: 0.8698 - 2ms/step\n", - "step 70/196 - loss: 0.4603 - acc: 0.8723 - 2ms/step\n", - "step 80/196 - loss: 0.4188 - acc: 0.8703 - 2ms/step\n", - "step 90/196 - loss: 0.5685 - acc: 0.8691 - 2ms/step\n", - "step 100/196 - loss: 0.4086 - acc: 0.8719 - 2ms/step\n", - "step 110/196 - loss: 0.4628 - acc: 0.8722 - 3ms/step\n", - "step 120/196 - loss: 0.3791 - acc: 0.8674 - 3ms/step\n", - "step 130/196 - loss: 0.4087 - acc: 0.8673 - 3ms/step\n", - "step 140/196 - loss: 0.4109 - acc: 0.8688 - 3ms/step\n", - "step 150/196 - loss: 0.4144 - acc: 0.8688 - 3ms/step\n", - "step 160/196 - loss: 0.5291 - acc: 0.8666 - 3ms/step\n", - "step 170/196 - loss: 0.4071 - acc: 0.8678 - 3ms/step\n", - "step 180/196 - loss: 0.3402 - acc: 0.8703 - 3ms/step\n", - "step 190/196 - loss: 0.4466 - acc: 0.8707 - 3ms/step\n", - "step 196/196 - loss: 0.4286 - acc: 0.8712 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.7192 - acc: 0.8723 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 8/10\n", - "step 10/586 - loss: 0.3689 - acc: 0.9531 - 7ms/step\n", - "step 20/586 - loss: 0.3800 - acc: 0.9531 - 7ms/step\n", - "step 30/586 - loss: 0.3609 - acc: 0.9583 - 7ms/step\n", - "step 40/586 - loss: 0.3177 - acc: 0.9586 - 7ms/step\n", - "step 50/586 - loss: 0.4016 - acc: 0.9594 - 7ms/step\n", - "step 60/586 - loss: 0.3537 - acc: 0.9609 - 8ms/step\n", - "step 70/586 - loss: 0.3203 - acc: 0.9616 - 7ms/step\n", - "step 80/586 - loss: 0.4411 - acc: 0.9609 - 7ms/step\n", - "step 90/586 - loss: 0.3150 - acc: 0.9639 - 7ms/step\n", - "step 100/586 - loss: 0.3471 - acc: 0.9641 - 7ms/step\n", - "step 110/586 - loss: 0.3501 - acc: 0.9642 - 7ms/step\n", - "step 120/586 - loss: 
0.3169 - acc: 0.9648 - 7ms/step\n", - "step 130/586 - loss: 0.3198 - acc: 0.9656 - 7ms/step\n", - "step 140/586 - loss: 0.3776 - acc: 0.9658 - 7ms/step\n", - "step 150/586 - loss: 0.3153 - acc: 0.9660 - 7ms/step\n", - "step 160/586 - loss: 0.3440 - acc: 0.9664 - 7ms/step\n", - "step 170/586 - loss: 0.3279 - acc: 0.9675 - 7ms/step\n", - "step 180/586 - loss: 0.3984 - acc: 0.9674 - 7ms/step\n", - "step 190/586 - loss: 0.3500 - acc: 0.9666 - 7ms/step\n", - "step 200/586 - loss: 0.3219 - acc: 0.9675 - 7ms/step\n", - "step 210/586 - loss: 0.3145 - acc: 0.9683 - 7ms/step\n", - "step 220/586 - loss: 0.3149 - acc: 0.9688 - 7ms/step\n", - "step 230/586 - loss: 0.3151 - acc: 0.9688 - 7ms/step\n", - "step 240/586 - loss: 0.3338 - acc: 0.9689 - 7ms/step\n", - "step 250/586 - loss: 0.3511 - acc: 0.9692 - 7ms/step\n", - "step 260/586 - loss: 0.3172 - acc: 0.9692 - 7ms/step\n", - "step 270/586 - loss: 0.3563 - acc: 0.9691 - 7ms/step\n", - "step 280/586 - loss: 0.3349 - acc: 0.9693 - 7ms/step\n", - "step 290/586 - loss: 0.3152 - acc: 0.9692 - 7ms/step\n", - "step 300/586 - loss: 0.3159 - acc: 0.9694 - 7ms/step\n", - "step 310/586 - loss: 0.3461 - acc: 0.9695 - 7ms/step\n", - "step 320/586 - loss: 0.3141 - acc: 0.9696 - 7ms/step\n", - "step 330/586 - loss: 0.3141 - acc: 0.9694 - 7ms/step\n", - "step 340/586 - loss: 0.3457 - acc: 0.9698 - 7ms/step\n", - "step 350/586 - loss: 0.4385 - acc: 0.9690 - 7ms/step\n", - "step 360/586 - loss: 0.3157 - acc: 0.9684 - 7ms/step\n", - "step 370/586 - loss: 0.3917 - acc: 0.9679 - 7ms/step\n", - "step 380/586 - loss: 0.3459 - acc: 0.9679 - 7ms/step\n", - "step 390/586 - loss: 0.3501 - acc: 0.9683 - 7ms/step\n", - "step 400/586 - loss: 0.3754 - acc: 0.9684 - 7ms/step\n", - "step 410/586 - loss: 0.3550 - acc: 0.9685 - 7ms/step\n", - "step 420/586 - loss: 0.3843 - acc: 0.9689 - 7ms/step\n", - "step 430/586 - loss: 0.3151 - acc: 0.9691 - 7ms/step\n", - "step 440/586 - loss: 0.3450 - acc: 0.9692 - 7ms/step\n", - "step 450/586 - loss: 0.3548 - acc: 0.9696 - 7ms/step\n", - "step 460/586 - loss: 0.3773 - acc: 0.9695 - 7ms/step\n", - "step 470/586 - loss: 0.3140 - acc: 0.9694 - 7ms/step\n", - "step 480/586 - loss: 0.3399 - acc: 0.9692 - 7ms/step\n", - "step 490/586 - loss: 0.3873 - acc: 0.9690 - 7ms/step\n", - "step 500/586 - loss: 0.3411 - acc: 0.9691 - 7ms/step\n", - "step 510/586 - loss: 0.3454 - acc: 0.9691 - 7ms/step\n", - "step 520/586 - loss: 0.3436 - acc: 0.9690 - 7ms/step\n", - "step 530/586 - loss: 0.3766 - acc: 0.9688 - 7ms/step\n", - "step 540/586 - loss: 0.3471 - acc: 0.9685 - 7ms/step\n", - "step 550/586 - loss: 0.3464 - acc: 0.9688 - 7ms/step\n", - "step 560/586 - loss: 0.3193 - acc: 0.9689 - 7ms/step\n", - "step 570/586 - loss: 0.3206 - acc: 0.9688 - 7ms/step\n", - "step 580/586 - loss: 0.4221 - acc: 0.9688 - 7ms/step\n", - "step 586/586 - loss: 0.4137 - acc: 0.9687 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3481 - acc: 0.9801 - 7ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4021 - acc: 0.8750 - 3ms/step\n", - "step 20/196 - loss: 0.3832 - acc: 0.8938 - 3ms/step\n", - "step 30/196 - loss: 0.5667 - acc: 0.8802 - 3ms/step\n", - "step 40/196 - loss: 0.4844 - acc: 0.8727 - 3ms/step\n", - "step 50/196 - loss: 0.4485 - acc: 0.8712 - 3ms/step\n", - "step 60/196 - loss: 0.3879 - acc: 0.8661 - 3ms/step\n", - "step 70/196 - loss: 0.4494 - acc: 0.8719 - 3ms/step\n", - "step 80/196 - loss: 0.4179 - acc: 0.8707 - 3ms/step\n", - "step 90/196 - loss: 0.5252 - acc: 0.8712 - 3ms/step\n", - "step 100/196 - loss: 0.3908 - acc: 0.8728 - 
3ms/step\n", - "step 110/196 - loss: 0.4374 - acc: 0.8730 - 2ms/step\n", - "step 120/196 - loss: 0.3779 - acc: 0.8685 - 2ms/step\n", - "step 130/196 - loss: 0.4083 - acc: 0.8680 - 2ms/step\n", - "step 140/196 - loss: 0.4196 - acc: 0.8688 - 2ms/step\n", - "step 150/196 - loss: 0.3966 - acc: 0.8683 - 2ms/step\n", - "step 160/196 - loss: 0.5057 - acc: 0.8670 - 2ms/step\n", - "step 170/196 - loss: 0.3764 - acc: 0.8676 - 2ms/step\n", - "step 180/196 - loss: 0.3452 - acc: 0.8693 - 2ms/step\n", - "step 190/196 - loss: 0.4252 - acc: 0.8689 - 2ms/step\n", - "step 196/196 - loss: 0.4172 - acc: 0.8696 - 2ms/step\n", + "step 196/196 [==============================] - loss: 0.7283 - acc: 0.8722 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 9/10\n", - "step 10/586 - loss: 0.3192 - acc: 0.9875 - 7ms/step\n", - "step 20/586 - loss: 0.3457 - acc: 0.9844 - 8ms/step\n", - "step 30/586 - loss: 0.3765 - acc: 0.9771 - 7ms/step\n", - "step 40/586 - loss: 0.3740 - acc: 0.9680 - 7ms/step\n", - "step 50/586 - loss: 0.3542 - acc: 0.9656 - 7ms/step\n", - "step 60/586 - loss: 0.3400 - acc: 0.9625 - 7ms/step\n", - "step 70/586 - loss: 0.3535 - acc: 0.9625 - 7ms/step\n", - "step 80/586 - loss: 0.3456 - acc: 0.9645 - 7ms/step\n", - "step 90/586 - loss: 0.3141 - acc: 0.9663 - 7ms/step\n", - "step 100/586 - loss: 0.3465 - acc: 0.9650 - 7ms/step\n", - "step 110/586 - loss: 0.3315 - acc: 0.9645 - 7ms/step\n", - "step 120/586 - loss: 0.3145 - acc: 0.9659 - 7ms/step\n", - "step 130/586 - loss: 0.3475 - acc: 0.9668 - 7ms/step\n", - "step 140/586 - loss: 0.3171 - acc: 0.9683 - 7ms/step\n", - "step 150/586 - loss: 0.3462 - acc: 0.9681 - 7ms/step\n", - "step 160/586 - loss: 0.3492 - acc: 0.9682 - 7ms/step\n", - "step 170/586 - loss: 0.3475 - acc: 0.9689 - 7ms/step\n", - "step 180/586 - loss: 0.3466 - acc: 0.9694 - 7ms/step\n", - "step 190/586 - loss: 0.4103 - acc: 0.9696 - 7ms/step\n", - "step 200/586 - loss: 0.3672 - acc: 0.9700 - 7ms/step\n", - "step 210/586 - loss: 0.4100 - acc: 0.9695 - 7ms/step\n", - "step 220/586 - loss: 0.4084 - acc: 0.9699 - 7ms/step\n", - "step 230/586 - loss: 0.3141 - acc: 0.9707 - 7ms/step\n", - "step 240/586 - loss: 0.3450 - acc: 0.9708 - 7ms/step\n", - "step 250/586 - loss: 0.3462 - acc: 0.9705 - 7ms/step\n", - "step 260/586 - loss: 0.3178 - acc: 0.9706 - 7ms/step\n", - "step 270/586 - loss: 0.3451 - acc: 0.9703 - 7ms/step\n", - "step 280/586 - loss: 0.3493 - acc: 0.9705 - 7ms/step\n", - "step 290/586 - loss: 0.3174 - acc: 0.9711 - 7ms/step\n", - "step 300/586 - loss: 0.3171 - acc: 0.9716 - 7ms/step\n", - "step 310/586 - loss: 0.3478 - acc: 0.9720 - 7ms/step\n", - "step 320/586 - loss: 0.3220 - acc: 0.9723 - 7ms/step\n", - "step 330/586 - loss: 0.3139 - acc: 0.9724 - 7ms/step\n", - "step 340/586 - loss: 0.3137 - acc: 0.9730 - 7ms/step\n", - "step 350/586 - loss: 0.4082 - acc: 0.9728 - 7ms/step\n", - "step 360/586 - loss: 0.3447 - acc: 0.9727 - 7ms/step\n", - "step 370/586 - loss: 0.3136 - acc: 0.9728 - 7ms/step\n", - "step 380/586 - loss: 0.3284 - acc: 0.9728 - 7ms/step\n", - "step 390/586 - loss: 0.4076 - acc: 0.9726 - 7ms/step\n", - "step 400/586 - loss: 0.3646 - acc: 0.9726 - 7ms/step\n", - "step 410/586 - loss: 0.3137 - acc: 0.9723 - 7ms/step\n", - "step 420/586 - loss: 0.3452 - acc: 0.9724 - 7ms/step\n", - "step 430/586 - loss: 0.3210 - acc: 0.9720 - 7ms/step\n", - "step 440/586 - loss: 0.3764 - acc: 0.9719 - 7ms/step\n", - "step 450/586 - loss: 0.3449 - acc: 0.9721 - 7ms/step\n", - "step 460/586 - loss: 0.3808 - acc: 0.9724 - 7ms/step\n", - "step 470/586 - loss: 0.3767 - acc: 0.9723 - 
7ms/step\n", - "step 480/586 - loss: 0.3582 - acc: 0.9720 - 7ms/step\n", - "step 490/586 - loss: 0.4074 - acc: 0.9721 - 7ms/step\n", - "step 500/586 - loss: 0.3281 - acc: 0.9724 - 7ms/step\n", - "step 510/586 - loss: 0.3197 - acc: 0.9725 - 7ms/step\n", - "step 520/586 - loss: 0.3449 - acc: 0.9725 - 7ms/step\n", - "step 530/586 - loss: 0.3772 - acc: 0.9723 - 7ms/step\n", - "step 540/586 - loss: 0.3460 - acc: 0.9723 - 7ms/step\n", - "step 550/586 - loss: 0.3758 - acc: 0.9719 - 7ms/step\n", - "step 560/586 - loss: 0.3837 - acc: 0.9720 - 7ms/step\n", - "step 570/586 - loss: 0.3185 - acc: 0.9718 - 7ms/step\n", - "step 580/586 - loss: 0.3173 - acc: 0.9720 - 7ms/step\n", - "step 586/586 - loss: 0.3142 - acc: 0.9721 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3217 - acc: 0.9802 - 7ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4118 - acc: 0.8562 - 3ms/step\n", - "step 20/196 - loss: 0.4136 - acc: 0.8688 - 3ms/step\n", - "step 30/196 - loss: 0.5431 - acc: 0.8729 - 3ms/step\n", - "step 40/196 - loss: 0.4878 - acc: 0.8641 - 2ms/step\n", - "step 50/196 - loss: 0.4139 - acc: 0.8675 - 2ms/step\n", - "step 60/196 - loss: 0.3872 - acc: 0.8646 - 2ms/step\n", - "step 70/196 - loss: 0.4269 - acc: 0.8692 - 2ms/step\n", - "step 80/196 - loss: 0.4665 - acc: 0.8668 - 2ms/step\n", - "step 90/196 - loss: 0.5964 - acc: 0.8670 - 2ms/step\n", - "step 100/196 - loss: 0.4225 - acc: 0.8709 - 2ms/step\n", - "step 110/196 - loss: 0.4720 - acc: 0.8696 - 2ms/step\n", - "step 120/196 - loss: 0.3814 - acc: 0.8635 - 2ms/step\n", - "step 130/196 - loss: 0.4242 - acc: 0.8635 - 3ms/step\n", - "step 140/196 - loss: 0.3902 - acc: 0.8661 - 3ms/step\n", - "step 150/196 - loss: 0.4303 - acc: 0.8648 - 3ms/step\n", - "step 160/196 - loss: 0.5004 - acc: 0.8633 - 3ms/step\n", - "step 170/196 - loss: 0.4446 - acc: 0.8632 - 3ms/step\n", - "step 180/196 - loss: 0.3417 - acc: 0.8656 - 3ms/step\n", - "step 190/196 - loss: 0.4667 - acc: 0.8660 - 3ms/step\n", - "step 196/196 - loss: 0.4134 - acc: 0.8664 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.6697 - acc: 0.8629 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 10/10\n", - "step 10/586 - loss: 0.3144 - acc: 0.9781 - 7ms/step\n", - "step 20/586 - loss: 0.3819 - acc: 0.9719 - 7ms/step\n", - "step 30/586 - loss: 0.3147 - acc: 0.9698 - 7ms/step\n", - "step 40/586 - loss: 0.3139 - acc: 0.9727 - 7ms/step\n", - "step 50/586 - loss: 0.3788 - acc: 0.9738 - 7ms/step\n", - "step 60/586 - loss: 0.3472 - acc: 0.9724 - 7ms/step\n", - "step 70/586 - loss: 0.3139 - acc: 0.9714 - 7ms/step\n", - "step 80/586 - loss: 0.3453 - acc: 0.9727 - 7ms/step\n", - "step 90/586 - loss: 0.3769 - acc: 0.9729 - 7ms/step\n", - "step 100/586 - loss: 0.3460 - acc: 0.9734 - 7ms/step\n", - "step 110/586 - loss: 0.3137 - acc: 0.9727 - 7ms/step\n", - "step 120/586 - loss: 0.3137 - acc: 0.9721 - 7ms/step\n", - "step 130/586 - loss: 0.3458 - acc: 0.9724 - 7ms/step\n", - "step 140/586 - loss: 0.3453 - acc: 0.9732 - 7ms/step\n", - "step 150/586 - loss: 0.3457 - acc: 0.9729 - 7ms/step\n", - "step 160/586 - loss: 0.3145 - acc: 0.9740 - 7ms/step\n", - "step 170/586 - loss: 0.3614 - acc: 0.9732 - 7ms/step\n", - "step 180/586 - loss: 0.3550 - acc: 0.9731 - 7ms/step\n", - "step 190/586 - loss: 0.3135 - acc: 0.9735 - 7ms/step\n", - "step 200/586 - loss: 0.3638 - acc: 0.9739 - 7ms/step\n", - "step 210/586 - loss: 0.3447 - acc: 0.9737 - 7ms/step\n", - "step 220/586 - loss: 0.3136 - acc: 0.9734 - 7ms/step\n", - "step 230/586 - loss: 0.3480 - acc: 0.9735 - 7ms/step\n", - "step 
240/586 - loss: 0.3144 - acc: 0.9734 - 7ms/step\n", - "step 250/586 - loss: 0.3147 - acc: 0.9740 - 7ms/step\n", - "step 260/586 - loss: 0.3135 - acc: 0.9742 - 7ms/step\n", - "step 270/586 - loss: 0.3768 - acc: 0.9748 - 7ms/step\n", - "step 280/586 - loss: 0.3455 - acc: 0.9749 - 7ms/step\n", - "step 290/586 - loss: 0.3147 - acc: 0.9748 - 7ms/step\n", - "step 300/586 - loss: 0.3765 - acc: 0.9745 - 7ms/step\n", - "step 310/586 - loss: 0.3761 - acc: 0.9742 - 7ms/step\n", - "step 320/586 - loss: 0.3487 - acc: 0.9739 - 7ms/step\n", - "step 330/586 - loss: 0.3621 - acc: 0.9739 - 7ms/step\n", - "step 340/586 - loss: 0.3145 - acc: 0.9738 - 7ms/step\n", - "step 350/586 - loss: 0.3135 - acc: 0.9738 - 7ms/step\n", - "step 360/586 - loss: 0.3454 - acc: 0.9740 - 7ms/step\n", - "step 370/586 - loss: 0.3145 - acc: 0.9744 - 7ms/step\n", - "step 380/586 - loss: 0.3454 - acc: 0.9745 - 7ms/step\n", - "step 390/586 - loss: 0.3462 - acc: 0.9747 - 7ms/step\n", - "step 400/586 - loss: 0.3152 - acc: 0.9750 - 7ms/step\n", - "step 410/586 - loss: 0.3473 - acc: 0.9753 - 7ms/step\n", - "step 420/586 - loss: 0.3449 - acc: 0.9754 - 7ms/step\n", - "step 430/586 - loss: 0.3154 - acc: 0.9757 - 7ms/step\n", - "step 440/586 - loss: 0.3457 - acc: 0.9759 - 7ms/step\n", - "step 450/586 - loss: 0.3457 - acc: 0.9757 - 7ms/step\n", - "step 460/586 - loss: 0.3447 - acc: 0.9757 - 7ms/step\n", - "step 470/586 - loss: 0.3137 - acc: 0.9757 - 7ms/step\n", - "step 480/586 - loss: 0.3139 - acc: 0.9759 - 7ms/step\n", - "step 490/586 - loss: 0.3473 - acc: 0.9760 - 7ms/step\n", - "step 500/586 - loss: 0.3155 - acc: 0.9759 - 7ms/step\n", - "step 510/586 - loss: 0.3760 - acc: 0.9757 - 7ms/step\n", - "step 520/586 - loss: 0.3452 - acc: 0.9755 - 7ms/step\n", - "step 530/586 - loss: 0.3139 - acc: 0.9756 - 7ms/step\n", - "step 540/586 - loss: 0.3139 - acc: 0.9756 - 7ms/step\n", - "step 550/586 - loss: 0.3143 - acc: 0.9757 - 7ms/step\n", - "step 560/586 - loss: 0.3144 - acc: 0.9759 - 7ms/step\n", - "step 570/586 - loss: 0.3450 - acc: 0.9759 - 7ms/step\n", - "step 580/586 - loss: 0.3245 - acc: 0.9758 - 7ms/step\n", - "step 586/586 - loss: 0.3829 - acc: 0.9756 - 7ms/step\n", + "step 586/586 [==============================] - loss: 0.3466 - acc: 0.9807 - 7ms/step \n", "Eval begin...\n", - "step 10/196 - loss: 0.4100 - acc: 0.8531 - 5ms/step\n", - "step 20/196 - loss: 0.4061 - acc: 0.8703 - 4ms/step\n", - "step 30/196 - loss: 0.5566 - acc: 0.8719 - 3ms/step\n", - "step 40/196 - loss: 0.4805 - acc: 0.8656 - 3ms/step\n", - "step 50/196 - loss: 0.4235 - acc: 0.8662 - 3ms/step\n", - "step 60/196 - loss: 0.4023 - acc: 0.8620 - 3ms/step\n", - "step 70/196 - loss: 0.4327 - acc: 0.8656 - 3ms/step\n", - "step 80/196 - loss: 0.4856 - acc: 0.8625 - 3ms/step\n", - "step 90/196 - loss: 0.5713 - acc: 0.8639 - 3ms/step\n", - "step 100/196 - loss: 0.3963 - acc: 0.8678 - 3ms/step\n", - "step 110/196 - loss: 0.4678 - acc: 0.8676 - 3ms/step\n", - "step 120/196 - loss: 0.4025 - acc: 0.8625 - 3ms/step\n", - "step 130/196 - loss: 0.4336 - acc: 0.8627 - 3ms/step\n", - "step 140/196 - loss: 0.3946 - acc: 0.8652 - 3ms/step\n", - "step 150/196 - loss: 0.4038 - acc: 0.8646 - 3ms/step\n", - "step 160/196 - loss: 0.5087 - acc: 0.8633 - 3ms/step\n", - "step 170/196 - loss: 0.4656 - acc: 0.8638 - 3ms/step\n", - "step 180/196 - loss: 0.3433 - acc: 0.8660 - 3ms/step\n", - "step 190/196 - loss: 0.4656 - acc: 0.8663 - 3ms/step\n", - "step 196/196 - loss: 0.4132 - acc: 0.8672 - 3ms/step\n", + "step 196/196 [==============================] - loss: 0.7435 - acc: 0.8726 - 2ms/step \n", 
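顺带说明:上面这处改动之所以能把逐步打印的日志压缩成进度条汇总,是因为 paddle.Model 的 fit 和 evaluate 支持 verbose 参数,与下文 fit 调用中新增的 verbose=1 对应。下面是一个最小示意(沿用上文已定义的 model、DataReader、train_x、eval_length 等名称;各取值的确切行为请以飞桨高层 API 文档为准,此处对 verbose=2 含义的说明是假定的):

# 最小示意:用 verbose 控制高层 API 的日志粒度(假设 model 已按上文 prepare 完成)
# 常见约定:verbose=0 静默,verbose=1 显示进度条,verbose=2 每个 epoch 打印一行
model.fit(train_data=DataReader(train_x[:-eval_length], train_y[:-eval_length], length),
          eval_data=DataReader(train_x[-eval_length:], train_y[-eval_length:], length),
          batch_size=32, epochs=10, verbose=2)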
"Eval samples: 6250\n" ] } @@ -1306,13 +533,7 @@ "\r\n", " def __len__(self):\r\n", " return len(self.data)\r\n", - "\r\n", - "\r\n", - "# 指定训练设备\r\n", - "device = pd.set_device('gpu') # 可选:cpu\r\n", - "\r\n", - "# 开启动态图模式\r\n", - "pd.disable_static(device)\r\n", + " \r\n", "\r\n", "# 定义输入格式\r\n", "input_form = pd.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", @@ -1327,7 +548,7 @@ "eval_length = int(len(train_x) * 1/4)\r\n", "model.fit(train_data=DataReader(train_x[:-eval_length], train_y[:-eval_length], length),\r\n", " eval_data=DataReader(train_x[-eval_length:], train_y[-eval_length:], length),\r\n", - " batch_size=32, epochs=10)" + " batch_size=32, epochs=10, verbose=1)" ] }, { @@ -1341,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": { "collapsed": false }, @@ -1351,85 +572,7 @@ "output_type": "stream", "text": [ "Eval begin...\n", - "step 10/782 - loss: 0.4515 - acc: 0.8531 - 3ms/step\n", - "step 20/782 - loss: 0.5053 - acc: 0.8656 - 3ms/step\n", - "step 30/782 - loss: 0.4896 - acc: 0.8406 - 3ms/step\n", - "step 40/782 - loss: 0.3849 - acc: 0.8469 - 3ms/step\n", - "step 50/782 - loss: 0.5705 - acc: 0.8331 - 3ms/step\n", - "step 60/782 - loss: 0.3480 - acc: 0.8370 - 3ms/step\n", - "step 70/782 - loss: 0.3403 - acc: 0.8460 - 3ms/step\n", - "step 80/782 - loss: 0.3370 - acc: 0.8473 - 3ms/step\n", - "step 90/782 - loss: 0.5180 - acc: 0.8462 - 3ms/step\n", - "step 100/782 - loss: 0.4266 - acc: 0.8481 - 3ms/step\n", - "step 110/782 - loss: 0.4605 - acc: 0.8486 - 3ms/step\n", - "step 120/782 - loss: 0.3836 - acc: 0.8477 - 3ms/step\n", - "step 130/782 - loss: 0.4657 - acc: 0.8474 - 3ms/step\n", - "step 140/782 - loss: 0.4203 - acc: 0.8462 - 3ms/step\n", - "step 150/782 - loss: 0.4735 - acc: 0.8408 - 3ms/step\n", - "step 160/782 - loss: 0.4959 - acc: 0.8412 - 3ms/step\n", - "step 170/782 - loss: 0.3490 - acc: 0.8419 - 3ms/step\n", - "step 180/782 - loss: 0.6037 - acc: 0.8415 - 3ms/step\n", - "step 190/782 - loss: 0.4110 - acc: 0.8416 - 3ms/step\n", - "step 200/782 - loss: 0.5318 - acc: 0.8430 - 3ms/step\n", - "step 210/782 - loss: 0.4332 - acc: 0.8449 - 3ms/step\n", - "step 220/782 - loss: 0.6212 - acc: 0.8447 - 3ms/step\n", - "step 230/782 - loss: 0.4884 - acc: 0.8443 - 3ms/step\n", - "step 240/782 - loss: 0.3646 - acc: 0.8434 - 3ms/step\n", - "step 250/782 - loss: 0.4735 - acc: 0.8446 - 3ms/step\n", - "step 260/782 - loss: 0.4272 - acc: 0.8460 - 3ms/step\n", - "step 270/782 - loss: 0.5258 - acc: 0.8453 - 3ms/step\n", - "step 280/782 - loss: 0.4614 - acc: 0.8449 - 3ms/step\n", - "step 290/782 - loss: 0.4773 - acc: 0.8454 - 3ms/step\n", - "step 300/782 - loss: 0.5187 - acc: 0.8441 - 3ms/step\n", - "step 310/782 - loss: 0.4952 - acc: 0.8431 - 3ms/step\n", - "step 320/782 - loss: 0.3959 - acc: 0.8435 - 3ms/step\n", - "step 330/782 - loss: 0.4840 - acc: 0.8437 - 3ms/step\n", - "step 340/782 - loss: 0.3650 - acc: 0.8441 - 3ms/step\n", - "step 350/782 - loss: 0.4842 - acc: 0.8450 - 3ms/step\n", - "step 360/782 - loss: 0.4866 - acc: 0.8444 - 3ms/step\n", - "step 370/782 - loss: 0.4882 - acc: 0.8454 - 3ms/step\n", - "step 380/782 - loss: 0.4428 - acc: 0.8434 - 3ms/step\n", - "step 390/782 - loss: 0.4084 - acc: 0.8430 - 3ms/step\n", - "step 400/782 - loss: 0.4584 - acc: 0.8433 - 3ms/step\n", - "step 410/782 - loss: 0.5239 - acc: 0.8442 - 3ms/step\n", - "step 420/782 - loss: 0.4221 - acc: 0.8453 - 3ms/step\n", - "step 430/782 - loss: 0.3200 - acc: 0.8466 - 3ms/step\n", - "step 440/782 - loss: 0.3503 - acc: 0.8479 - 
3ms/step\n", - "step 450/782 - loss: 0.4750 - acc: 0.8488 - 3ms/step\n", - "step 460/782 - loss: 0.4753 - acc: 0.8505 - 3ms/step\n", - "step 470/782 - loss: 0.5096 - acc: 0.8504 - 3ms/step\n", - "step 480/782 - loss: 0.4834 - acc: 0.8513 - 3ms/step\n", - "step 490/782 - loss: 0.3860 - acc: 0.8527 - 3ms/step\n", - "step 500/782 - loss: 0.5332 - acc: 0.8533 - 3ms/step\n", - "step 510/782 - loss: 0.4014 - acc: 0.8533 - 3ms/step\n", - "step 520/782 - loss: 0.4066 - acc: 0.8547 - 3ms/step\n", - "step 530/782 - loss: 0.4554 - acc: 0.8557 - 3ms/step\n", - "step 540/782 - loss: 0.5141 - acc: 0.8560 - 3ms/step\n", - "step 550/782 - loss: 0.4621 - acc: 0.8568 - 3ms/step\n", - "step 560/782 - loss: 0.4383 - acc: 0.8576 - 3ms/step\n", - "step 570/782 - loss: 0.3677 - acc: 0.8584 - 3ms/step\n", - "step 580/782 - loss: 0.5716 - acc: 0.8588 - 3ms/step\n", - "step 590/782 - loss: 0.4613 - acc: 0.8596 - 3ms/step\n", - "step 600/782 - loss: 0.4694 - acc: 0.8602 - 3ms/step\n", - "step 610/782 - loss: 0.3561 - acc: 0.8609 - 3ms/step\n", - "step 620/782 - loss: 0.4349 - acc: 0.8608 - 3ms/step\n", - "step 630/782 - loss: 0.4117 - acc: 0.8618 - 3ms/step\n", - "step 640/782 - loss: 0.3703 - acc: 0.8621 - 3ms/step\n", - "step 650/782 - loss: 0.3898 - acc: 0.8623 - 3ms/step\n", - "step 660/782 - loss: 0.4767 - acc: 0.8625 - 3ms/step\n", - "step 670/782 - loss: 0.4580 - acc: 0.8626 - 3ms/step\n", - "step 680/782 - loss: 0.4189 - acc: 0.8622 - 3ms/step\n", - "step 690/782 - loss: 0.4569 - acc: 0.8622 - 3ms/step\n", - "step 700/782 - loss: 0.3807 - acc: 0.8627 - 3ms/step\n", - "step 710/782 - loss: 0.4707 - acc: 0.8632 - 3ms/step\n", - "step 720/782 - loss: 0.3709 - acc: 0.8633 - 3ms/step\n", - "step 730/782 - loss: 0.4519 - acc: 0.8643 - 3ms/step\n", - "step 740/782 - loss: 0.4227 - acc: 0.8651 - 3ms/step\n", - "step 750/782 - loss: 0.4386 - acc: 0.8651 - 3ms/step\n", - "step 760/782 - loss: 0.3844 - acc: 0.8653 - 3ms/step\n", - "step 770/782 - loss: 0.3988 - acc: 0.8657 - 3ms/step\n", - "step 780/782 - loss: 0.3374 - acc: 0.8662 - 3ms/step\n", - "step 782/782 - loss: 0.4368 - acc: 0.8664 - 3ms/step\n", + "step 782/782 [==============================] - loss: 0.4383 - acc: 0.8644 - 2ms/step \n", "Eval samples: 25000\n", "Predict begin...\n", "step 10/10 [==============================] - 2ms/step \n", @@ -1449,7 +592,7 @@ ], "source": [ "# 评估\r\n", - "model.evaluate(eval_data=DataReader(test_x, test_y, length), batch_size=32)\r\n", + "model.evaluate(eval_data=DataReader(test_x, test_y, length), batch_size=32, verbose=1)\r\n", "\r\n", "# 预测\r\n", "true_y = test_y[100:105] + test_y[-110:-105]\r\n", From 5c513869a65bab435cc0fbeaac13c2ce918c3cc9 Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 12 Dec 2020 16:20:11 +0800 Subject: [PATCH 03/14] Delete pretrained_word_embeddings.ipynb --- .../pretrained_word_embeddings.ipynb | 627 ------------------ 1 file changed, 627 deletions(-) delete mode 100644 paddle2.0_docs/pretrained_word_embeddings.ipynb diff --git a/paddle2.0_docs/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings.ipynb deleted file mode 100644 index 461892aa..00000000 --- a/paddle2.0_docs/pretrained_word_embeddings.ipynb +++ /dev/null @@ -1,627 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "# 使用预训练的词向量\n", - "\n", - "Author: [Dongyang Yan](623320480@qq.com, github.com/fiyen )\n", - "\n", - "Data created: 2020/11/23\n", - "\n", - "Last modified: 2020/11/24\n", - "\n", - 
"Description: Tutorial to classify Imdb data using pre-trained word embeddings in paddlepaddle 2.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 摘要\n", - "\n", - "在这个示例中,我们将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imbd将直接调用自飞桨2.0,同时,\n", - "利用预训练的词向量([GloVe embedding](http://nlp.stanford.edu/projects/glove/))完成任务。" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 环境设置" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "import paddle as pd\r\n", - "from paddle.io import Dataset\r\n", - "import numpy as np\r\n", - "import paddle.text as pt\r\n", - "import random" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 用飞桨2.0调用Imdb数据集\n", - "由于飞桨2.0提供了经过处理的Imdb数据集,我们可以方便地调用所需要的数据实例,省去了数据预处理的麻烦。目前,飞桨2.0以及内置的高质量\n", - "数据集包括Conll05st、Imdb、Imikolov、Movielens、HCIHousing、WMT14和WMT16等,未来还将提供更多常用数据集的调用接口。" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Cache file /home/aistudio/.cache/paddle/dataset/imdb/imdb%2FaclImdb_v1.tar.gz not found, downloading https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz \n", - "Begin to download\n", - "\n", - "Download finished\n" - ] - } - ], - "source": [ - "imdb_train = pt.Imdb(mode='train', cutoff=150)\r\n", - "imdb_test = pt.Imdb(mode='test', cutoff=150)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "调用Imdb得到的是经过编码的内容。每个样本表示一个文档,以list的形式储存,list中的每个元素都由一个数字表示,对应文档相应位置的某个单词,\n", - "而单词和数字编码是一一对应的。其对应关系可以通过imdb_train.word_idx查看。我们可以检查一下以上生成的数据内容:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "训练集样本数量: 25000; 测试集样本数量: 25000\n", - "样本标签: {0, 1}\n", - "样本字典: [(b'the', 0), (b'and', 1), (b'a', 2), (b'of', 3), (b'to', 4), (b'is', 5), (b'in', 6), (b'it', 7), (b'i', 8), (b'this', 9)]\n", - "单个样本: [5146, 43, 71, 6, 1092, 14, 0, 878, 130, 151, 5146, 18, 281, 747, 0, 5146, 3, 5146, 2165, 37, 5146, 46, 5, 71, 4089, 377, 162, 46, 5, 32, 1287, 300, 35, 203, 2136, 565, 14, 2, 253, 26, 146, 61, 372, 1, 615, 5146, 5, 30, 0, 50, 3290, 6, 2148, 14, 0, 5146, 11, 17, 451, 24, 4, 127, 10, 0, 878, 130, 43, 2, 50, 5146, 751, 5146, 5, 2, 221, 3727, 6, 9, 1167, 373, 9, 5, 5146, 7, 5, 1343, 13, 2, 5146, 1, 250, 7, 98, 4270, 56, 2316, 0, 928, 11, 11, 9, 16, 5, 5146, 5146, 6, 50, 69, 27, 280, 27, 108, 1045, 0, 2633, 4177, 3180, 17, 1675, 1, 2571]\n", - "最小样本长度: 10;最大样本长度: 2469\n" - ] - } - ], - "source": [ - "print(\"训练集样本数量: %d; 测试集样本数量: %d\" % (len(imdb_train), len(imdb_test)))\r\n", - "print(f\"样本标签: {set(imdb_train.labels)}\")\r\n", - "print(f\"样本字典: {list(imdb_train.word_idx.items())[:10]}\")\r\n", - "print(f\"单个样本: {imdb_train.docs[0]}\")\r\n", - "print(f\"最小样本长度: {min([len(x) for x in imdb_train.docs])};最大样本长度: {max([len(x) for x in imdb_train.docs])}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "以上参数中,cutoff定义了构建词典的截止大小,即数据集中出现频率在cutoff以下的不予考虑;mode定义了返回的数据用于何种用途(test: \n", - "测试集,train: 训练集)。对于训练集,我们将数据的顺序打乱,以优化将要进行的分类模型训练的效果。" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - 
"shuffle_index = list(range(len(imdb_train)))\r\n", - "random.shuffle(shuffle_index)\r\n", - "train_x = [imdb_train.docs[i] for i in shuffle_index]\r\n", - "train_y = [imdb_train.labels[i] for i in shuffle_index]\r\n", - "\r\n", - "test_x = imdb_test.docs\r\n", - "test_y = imdb_test.labels" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "从样本长度上可以看到,每个样本的长度是不相同的。然而,在模型的训练过程中,需要保证每个样本的长度相同,以便于构造矩阵进行批量运算。\n", - "因此,我们需要先对所有样本进行填充或截断,使样本的长度一致。" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "def vectorizer(input, label=None, length=2000):\r\n", - " if label is not None:\r\n", - " for x, y in zip(input, label):\r\n", - " yield np.array((x + [0]*length)[:2000]).astype('int64'), np.array([y]).astype('int64')\r\n", - " else:\r\n", - " for x in input:\r\n", - " yield np.array((x + [0]*length)[:2000]).astype('int64')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 载入预训练向量。\n", - "以下给出的文件较小,可以直接完全载入内存。对于大型的预训练向量,无法一次载入内存的,可以采用分批载入,并行\n", - "处理的方式进行匹配。这里略过此部分,如果感兴趣可以参考[此链接](https://aistudio.baidu.com/aistudio/projectdetail/496368)进一步了解。" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# 下载预训练向量文件,此链接下载较慢,较快下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", - "!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", - "!unzip -q glove.6B.zip\r\n", - "\r\n", - "glove_path = \"./glove.6B.100d.txt\" # 请修改至glove.6B.100d.txt所在位置\r\n", - "embeddings = {}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "我们先观察上述GloVe预训练向量文件一行的数据:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GloVe单行数据:'the -0.038194 -0.24487 0.72812 -0.39961 0.083172 0.043953 -0.39141 0.3344 -0.57545 0.087459 0.28787 -0.06731 0.30906 -0.26384 -0.13231 -0.20757 0.33395 -0.33848 -0.31743 -0.48336 0.1464 -0.37304 0.34577 0.052041 0.44946 -0.46971 0.02628 -0.54155 -0.15518 -0.14107 -0.039722 0.28277 0.14393 0.23464 -0.31021 0.086173 0.20397 0.52624 0.17164 -0.082378 -0.71787 -0.41531 0.20335 -0.12763 0.41367 0.55187 0.57908 -0.33477 -0.36559 -0.54857 -0.062892 0.26584 0.30205 0.99775 -0.80481 -3.0243 0.01254 -0.36942 2.2167 0.72201 -0.24978 0.92136 0.034514 0.46745 1.1079 -0.19358 -0.074575 0.23353 -0.052062 -0.22044 0.057162 -0.15806 -0.30798 -0.41625 0.37972 0.15006 -0.53212 -0.2055 -1.2526 0.071624 0.70565 0.49744 -0.42063 0.26148 -1.538 -0.30223 -0.073438 -0.28312 0.37104 -0.25217 0.016215 -0.017099 -0.38984 0.87424 -0.72569 -0.51058 -0.52028 -0.1459 0.8278 0.27062\n", - "'\n" - ] - } - ], - "source": [ - "# 使用utf8编码解码\r\n", - "with open(glove_path, encoding='utf-8') as gf:\r\n", - " line = gf.readline()\r\n", - " print(\"GloVe单行数据:'%s'\" % line)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "可以看到,每一行都以单词开头,其后接上该单词的向量值,各个值之间用空格隔开。基于此,可以用如下方法得到所有词向量的字典。" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "预训练词向量总数:400000\n", - "单词'the'的向量是:[-0.038194, -0.24487, 0.72812, -0.39961, 0.083172, 0.043953, -0.39141, 0.3344, -0.57545, 0.087459, 0.28787, -0.06731, 0.30906, -0.26384, -0.13231, 
-0.20757, 0.33395, -0.33848, -0.31743, -0.48336, 0.1464, -0.37304, 0.34577, 0.052041, 0.44946, -0.46971, 0.02628, -0.54155, -0.15518, -0.14107, -0.039722, 0.28277, 0.14393, 0.23464, -0.31021, 0.086173, 0.20397, 0.52624, 0.17164, -0.082378, -0.71787, -0.41531, 0.20335, -0.12763, 0.41367, 0.55187, 0.57908, -0.33477, -0.36559, -0.54857, -0.062892, 0.26584, 0.30205, 0.99775, -0.80481, -3.0243, 0.01254, -0.36942, 2.2167, 0.72201, -0.24978, 0.92136, 0.034514, 0.46745, 1.1079, -0.19358, -0.074575, 0.23353, -0.052062, -0.22044, 0.057162, -0.15806, -0.30798, -0.41625, 0.37972, 0.15006, -0.53212, -0.2055, -1.2526, 0.071624, 0.70565, 0.49744, -0.42063, 0.26148, -1.538, -0.30223, -0.073438, -0.28312, 0.37104, -0.25217, 0.016215, -0.017099, -0.38984, 0.87424, -0.72569, -0.51058, -0.52028, -0.1459, 0.8278, 0.27062]\n" - ] - } - ], - "source": [ - "with open(glove_path, encoding='utf-8') as gf:\r\n", - " for glove in gf:\r\n", - " word, embedding = glove.split(maxsplit=1)\r\n", - " embedding = [float(s) for s in embedding.split(' ')]\r\n", - " embeddings[word] = embedding\r\n", - "print(\"预训练词向量总数:%d\" % len(embeddings))\r\n", - "print(f\"单词'the'的向量是:{embeddings['the']}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 给数据集的词表匹配词向量\n", - "接下来,我们提取数据集的词表,需要注意的是,词表中的词编码的先后顺序是按照词出现的频率排列的,频率越高的词编码值越小。" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "词表的前5个单词:[b'the', b'and', b'a', b'of', b'to']\n", - "词表的后5个单词:[b'troubles', b'virtual', b'warriors', b'widely', '']\n" - ] - } - ], - "source": [ - "word_idx = imdb_train.word_idx\r\n", - "vocab = [w for w in word_idx.keys()]\r\n", - "print(f\"词表的前5个单词:{vocab[:5]}\")\r\n", - "print(f\"词表的后5个单词:{vocab[-5:]}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "观察词表的后5个单词,我们发现,最后一个词是\"\\\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", - "的二进制编码形式,使用中注意使用b'the'.decode()来进行转换('\\'并没有进行二进制编码,注意区分)。\n", - "接下来,我们给词表中的每个词匹配对应的词向量。预训练词向量可能没有覆盖数据集词表中的所有词,对于没有的词,我们设该词的词\n", - "向量为零向量。" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# 定义词向量的维度,注意与预训练词向量保持一致\r\n", - "dim = 100\r\n", - "\r\n", - "vocab_embeddings = np.zeros((len(vocab), dim))\r\n", - "for ind, word in enumerate(vocab):\r\n", - " if word != '':\r\n", - " word = word.decode()\r\n", - " embedding = embeddings.get(word, np.zeros((dim,)))\r\n", - " vocab_embeddings[ind, :] = embedding" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 构建基于预训练向量的Embedding\n", - "对于预训练向量的Embedding,我们一般期望它的参数不再变动,所以要设置trainable=False。如果希望在此基础上训练参数,则需要\n", - "设置trainable=True。" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "pretrained_attr = pd.ParamAttr(name='embedding',\r\n", - " initializer=pd.nn.initializer.Assign(vocab_embeddings),\r\n", - " trainable=False)\r\n", - "embedding_layer = pd.nn.Embedding(num_embeddings=len(vocab),\r\n", - " embedding_dim=dim,\r\n", - " padding_idx=word_idx[''],\r\n", - " weight_attr=pretrained_attr)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 构建分类器\n", - "这里,我们构建简单的基于一维卷积的分类模型,其结构为:Embedding->Conv1D->Pool1D->Linear。在定义Linear时,由于需要知\n", - 
"道输入向量的维度,我们可以按照公式[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-beta/api/paddle/nn/layer/conv/Conv2d_cn.html)\n", - "来进行计算。这里给出计算的函数如下:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------------------------------------------------------------------------\n", - " Layer (type) Input Shape Output Shape Param # \n", - "===========================================================================\n", - " Embedding-1 [[1, 2000]] [1, 2000, 100] 514,700 \n", - " Conv1D-1 [[1, 2000, 100]] [1, 998, 10] 5,010 \n", - " ReLU-1 [[1, 998, 10]] [1, 998, 10] 0 \n", - " MaxPool1D-1 [[1, 998, 10]] [1, 998, 5] 0 \n", - " Flatten-1 [[1, 998, 5]] [1, 4990] 0 \n", - " Linear-1 [[1, 4990]] [1, 2] 9,982 \n", - " Softmax-1 [[1, 2]] [1, 2] 0 \n", - "===========================================================================\n", - "Total params: 529,692\n", - "Trainable params: 529,692\n", - "Non-trainable params: 0\n", - "---------------------------------------------------------------------------\n", - "Input size (MB): 0.01\n", - "Forward/backward pass size (MB): 1.75\n", - "Params size (MB): 2.02\n", - "Estimated Total Size (MB): 3.78\n", - "---------------------------------------------------------------------------\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "{'total_params': 529692, 'trainable_params': 529692}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def cal_output_shape(input_shape, out_channels, kernel_size, stride, padding=0, dilation=1):\r\n", - " return out_channels, int((input_shape + 2*padding - (dilation*(kernel_size - 1) + 1)) / stride) + 1\r\n", - "\r\n", - "\r\n", - "# 定义每个样本的长度\r\n", - "length = 2000\r\n", - "\r\n", - "# 定义卷积层参数\r\n", - "kernel_size = 5\r\n", - "out_channels = 10\r\n", - "stride = 2\r\n", - "padding = 0\r\n", - "\r\n", - "output_shape = cal_output_shape(length, out_channels, kernel_size, stride, padding)\r\n", - "output_shape = cal_output_shape(output_shape[1], output_shape[0], 2, 2, 0)\r\n", - "sim_model = pd.nn.Sequential(embedding_layer,\r\n", - " pd.nn.Conv1D(in_channels=dim, out_channels=out_channels, kernel_size=kernel_size,\r\n", - " stride=stride, padding=padding, data_format='NLC', bias_attr=True),\r\n", - " pd.nn.ReLU(),\r\n", - " pd.nn.MaxPool1D(kernel_size=2, stride=2),\r\n", - " pd.nn.Flatten(),\r\n", - " pd.nn.Linear(in_features=np.prod(output_shape), out_features=2, bias_attr=True),\r\n", - " pd.nn.Softmax())\r\n", - "\r\n", - "pd.summary(sim_model, input_size=(-1, length), dtypes='int64')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 读取数据,进行训练\n", - "我们可以利用飞桨2.0的io.Dataset模块来构建一个数据的读取器,方便地将数据进行分批训练。" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/10\n", - "step 586/586 [==============================] - loss: 0.3736 - acc: 0.9740 - 6ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.5626 - acc: 0.8726 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 2/10\n", - "step 586/586 [==============================] - loss: 0.3499 - acc: 0.9748 - 6ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7976 - acc: 0.8651 - 2ms/step \n", - "Eval samples: 6250\n", 
- "Epoch 3/10\n", - "step 586/586 [==============================] - loss: 0.3137 - acc: 0.9756 - 6ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.6264 - acc: 0.8701 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 4/10\n", - "step 586/586 [==============================] - loss: 0.3470 - acc: 0.9772 - 6ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.6550 - acc: 0.8714 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 5/10\n", - "step 586/586 [==============================] - loss: 0.3507 - acc: 0.9776 - 7ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7118 - acc: 0.8726 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 6/10\n", - "step 586/586 [==============================] - loss: 0.3466 - acc: 0.9781 - 7ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7157 - acc: 0.8725 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 7/10\n", - "step 586/586 [==============================] - loss: 0.3139 - acc: 0.9781 - 6ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7192 - acc: 0.8723 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 8/10\n", - "step 586/586 [==============================] - loss: 0.3481 - acc: 0.9801 - 7ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7283 - acc: 0.8722 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 9/10\n", - "step 586/586 [==============================] - loss: 0.3217 - acc: 0.9802 - 7ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.6697 - acc: 0.8629 - 2ms/step \n", - "Eval samples: 6250\n", - "Epoch 10/10\n", - "step 586/586 [==============================] - loss: 0.3466 - acc: 0.9807 - 7ms/step \n", - "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.7435 - acc: 0.8726 - 2ms/step \n", - "Eval samples: 6250\n" - ] - } - ], - "source": [ - "class DataReader(Dataset):\r\n", - " def __init__(self, input, label, length):\r\n", - " self.data = list(vectorizer(input, label, length=length))\r\n", - "\r\n", - " def __getitem__(self, idx):\r\n", - " return self.data[idx]\r\n", - "\r\n", - " def __len__(self):\r\n", - " return len(self.data)\r\n", - " \r\n", - "\r\n", - "# 定义输入格式\r\n", - "input_form = pd.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", - "label_form = pd.static.InputSpec(shape=[None, 1], dtype='int64', name='label')\r\n", - "\r\n", - "model = pd.Model(sim_model, input_form, label_form)\r\n", - "model.prepare(optimizer=pd.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),\r\n", - " loss=pd.nn.loss.CrossEntropyLoss(),\r\n", - " metrics=pd.metric.Accuracy())\r\n", - "\r\n", - "# 分割训练集和验证集\r\n", - "eval_length = int(len(train_x) * 1/4)\r\n", - "model.fit(train_data=DataReader(train_x[:-eval_length], train_y[:-eval_length], length),\r\n", - " eval_data=DataReader(train_x[-eval_length:], train_y[-eval_length:], length),\r\n", - " batch_size=32, epochs=10, verbose=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false - }, - "source": [ - "## 评估效果并用模型预测" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Eval begin...\n", - "step 782/782 [==============================] - loss: 0.4383 - acc: 0.8644 - 2ms/step \n", 
- "Eval samples: 25000\n", - "Predict begin...\n", - "step 10/10 [==============================] - 2ms/step \n", - "Predict samples: 10\n", - "预测的标签是:0, 实际标签是:0\n", - "预测的标签是:0, 实际标签是:0\n", - "预测的标签是:0, 实际标签是:0\n", - "预测的标签是:0, 实际标签是:0\n", - "预测的标签是:0, 实际标签是:0\n", - "预测的标签是:1, 实际标签是:1\n", - "预测的标签是:1, 实际标签是:1\n", - "预测的标签是:1, 实际标签是:1\n", - "预测的标签是:1, 实际标签是:1\n", - "预测的标签是:1, 实际标签是:1\n" - ] - } - ], - "source": [ - "# 评估\r\n", - "model.evaluate(eval_data=DataReader(test_x, test_y, length), batch_size=32, verbose=1)\r\n", - "\r\n", - "# 预测\r\n", - "true_y = test_y[100:105] + test_y[-110:-105]\r\n", - "pred_y = model.predict(DataReader(test_x[100:105] + test_x[-110:-105], None, length), batch_size=1)\r\n", - "\r\n", - "for index, y in enumerate(pred_y[0]):\r\n", - " print(\"预测的标签是:%d, 实际标签是:%d\" % (np.argmax(y), true_y[index]))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", - "language": "python", - "name": "py35-paddle1.2.0" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} From 902af1c68c0c049745b38fd6e9f34c6edc68e619 Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 12 Dec 2020 16:22:58 +0800 Subject: [PATCH 04/14] =?UTF-8?q?=E9=A3=9E=E6=A1=A82.0=E5=BA=94=E7=94=A8?= =?UTF-8?q?=E6=A1=88=E4=BE=8B=E2=80=94=E2=80=94=E4=BD=BF=E7=94=A8=E9=A2=84?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=9A=84=E8=AF=8D=E5=90=91=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 删除了原先的位置,增加了预训练词向量文件夹 --- .../pretrained_word_embeddings.ipynb | 629 ++++++++++++++++++ 1 file changed, 629 insertions(+) create mode 100644 paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb diff --git a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb new file mode 100644 index 00000000..71acb0c8 --- /dev/null +++ b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb @@ -0,0 +1,629 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 使用预训练的词向量\n", + "\n", + "Author: [Dongyang Yan](623320480@qq.com, github.com/fiyen )\n", + "\n", + "Data created: 2020/11/23\n", + "\n", + "Last modified: 2020/12/12\n", + "\n", + "Description: Tutorial to classify Imdb data using pre-trained word embeddings in paddlepaddle 2.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 摘要\n", + "\n", + "在这个示例中,我们将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imbd将直接调用自飞桨2.0,同时,\n", + "利用预训练的词向量([GloVe embedding](http://nlp.stanford.edu/projects/glove/))完成任务。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 环境设置" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import paddle as pd\r\n", + "from paddle.io import Dataset\r\n", + "import numpy as np\r\n", + "import paddle.text as pt\r\n", + "import random" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 用飞桨2.0调用Imdb数据集\n", + 
"由于飞桨2.0提供了经过处理的Imdb数据集,我们可以方便地调用所需要的数据实例,省去了数据预处理的麻烦。目前,飞桨2.0以及内置的高质量\n", + "数据集包括Conll05st、Imdb、Imikolov、Movielens、HCIHousing、WMT14和WMT16等,未来还将提供更多常用数据集的调用接口。\n", + "\n", + "以下定义了调用imdb训练集合测试集的方法。其中,cutoff定义了构建词典的截止大小,即数据集中出现频率在cutoff以下的不予考虑;mode定义了返回的数据用于何种用途(test: \n", + "测试集,train: 训练集)。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "imdb_train = pt.Imdb(mode='train', cutoff=150)\r\n", + "imdb_test = pt.Imdb(mode='test', cutoff=150)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "调用Imdb得到的是经过编码的数据集,每个term对应一个唯一id,映射关系可以通过imdb_train.word_idx查看。将每一个样本即一条电影评论,表示成id序列。我们可以检查一下以上生成的数据内容:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "训练集样本数量: 25000; 测试集样本数量: 25000\n", + "样本标签: {0, 1}\n", + "样本字典: [(b'the', 0), (b'and', 1), (b'a', 2), (b'of', 3), (b'to', 4), (b'is', 5), (b'in', 6), (b'it', 7), (b'i', 8), (b'this', 9)]\n", + "单个样本: [5146, 43, 71, 6, 1092, 14, 0, 878, 130, 151, 5146, 18, 281, 747, 0, 5146, 3, 5146, 2165, 37, 5146, 46, 5, 71, 4089, 377, 162, 46, 5, 32, 1287, 300, 35, 203, 2136, 565, 14, 2, 253, 26, 146, 61, 372, 1, 615, 5146, 5, 30, 0, 50, 3290, 6, 2148, 14, 0, 5146, 11, 17, 451, 24, 4, 127, 10, 0, 878, 130, 43, 2, 50, 5146, 751, 5146, 5, 2, 221, 3727, 6, 9, 1167, 373, 9, 5, 5146, 7, 5, 1343, 13, 2, 5146, 1, 250, 7, 98, 4270, 56, 2316, 0, 928, 11, 11, 9, 16, 5, 5146, 5146, 6, 50, 69, 27, 280, 27, 108, 1045, 0, 2633, 4177, 3180, 17, 1675, 1, 2571]\n", + "最小样本长度: 10;最大样本长度: 2469\n" + ] + } + ], + "source": [ + "print(\"训练集样本数量: %d; 测试集样本数量: %d\" % (len(imdb_train), len(imdb_test)))\r\n", + "print(f\"样本标签: {set(imdb_train.labels)}\")\r\n", + "print(f\"样本字典: {list(imdb_train.word_idx.items())[:10]}\")\r\n", + "print(f\"单个样本: {imdb_train.docs[0]}\")\r\n", + "print(f\"最小样本长度: {min([len(x) for x in imdb_train.docs])};最大样本长度: {max([len(x) for x in imdb_train.docs])}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "对于训练集,我们将数据的顺序打乱,以优化将要进行的分类模型训练的效果。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "shuffle_index = list(range(len(imdb_train)))\r\n", + "random.shuffle(shuffle_index)\r\n", + "train_x = [imdb_train.docs[i] for i in shuffle_index]\r\n", + "train_y = [imdb_train.labels[i] for i in shuffle_index]\r\n", + "\r\n", + "test_x = imdb_test.docs\r\n", + "test_y = imdb_test.labels" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "从样本长度上可以看到,每个样本的长度是不相同的。然而,在模型的训练过程中,需要保证每个样本的长度相同,以便于构造矩阵进行批量运算。\n", + "因此,我们需要先对所有样本进行填充或截断,使样本的长度一致。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def vectorizer(input, label=None, length=2000):\r\n", + " if label is not None:\r\n", + " for x, y in zip(input, label):\r\n", + " yield np.array((x + [0]*length)[:2000]).astype('int64'), np.array([y]).astype('int64')\r\n", + " else:\r\n", + " for x in input:\r\n", + " yield np.array((x + [0]*length)[:2000]).astype('int64')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 载入预训练向量。\n", + "以下给出的文件较小,可以直接完全载入内存。对于大型的预训练向量,无法一次载入内存的,可以采用分批载入,并行\n", + 
"处理的方式进行匹配。这里略过此部分,如果感兴趣可以参考[此链接](https://aistudio.baidu.com/aistudio/projectdetail/496368)进一步了解。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 下载预训练向量文件,此链接下载较慢,较快下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", + "#!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", + "#!unzip -q glove.6B.zip\r\n", + "\r\n", + "glove_path = \"./data/data42051/glove.6B.100d.txt\" # 请修改至glove.6B.100d.txt所在位置\r\n", + "embeddings = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "我们先观察上述GloVe预训练向量文件一行的数据:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GloVe单行数据:'the -0.038194 -0.24487 0.72812 -0.39961 0.083172 0.043953 -0.39141 0.3344 -0.57545 0.087459 0.28787 -0.06731 0.30906 -0.26384 -0.13231 -0.20757 0.33395 -0.33848 -0.31743 -0.48336 0.1464 -0.37304 0.34577 0.052041 0.44946 -0.46971 0.02628 -0.54155 -0.15518 -0.14107 -0.039722 0.28277 0.14393 0.23464 -0.31021 0.086173 0.20397 0.52624 0.17164 -0.082378 -0.71787 -0.41531 0.20335 -0.12763 0.41367 0.55187 0.57908 -0.33477 -0.36559 -0.54857 -0.062892 0.26584 0.30205 0.99775 -0.80481 -3.0243 0.01254 -0.36942 2.2167 0.72201 -0.24978 0.92136 0.034514 0.46745 1.1079 -0.19358 -0.074575 0.23353 -0.052062 -0.22044 0.057162 -0.15806 -0.30798 -0.41625 0.37972 0.15006 -0.53212 -0.2055 -1.2526 0.071624 0.70565 0.49744 -0.42063 0.26148 -1.538 -0.30223 -0.073438 -0.28312 0.37104 -0.25217 0.016215 -0.017099 -0.38984 0.87424 -0.72569 -0.51058 -0.52028 -0.1459 0.8278 0.27062\n", + "'\n" + ] + } + ], + "source": [ + "# 使用utf8编码解码\r\n", + "with open(glove_path, encoding='utf-8') as gf:\r\n", + " line = gf.readline()\r\n", + " print(\"GloVe单行数据:'%s'\" % line)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "可以看到,每一行都以单词开头,其后接上该单词的向量值,各个值之间用空格隔开。基于此,可以用如下方法得到所有词向量的字典。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "预训练词向量总数:400000\n", + "单词'the'的向量是:[-0.038194, -0.24487, 0.72812, -0.39961, 0.083172, 0.043953, -0.39141, 0.3344, -0.57545, 0.087459, 0.28787, -0.06731, 0.30906, -0.26384, -0.13231, -0.20757, 0.33395, -0.33848, -0.31743, -0.48336, 0.1464, -0.37304, 0.34577, 0.052041, 0.44946, -0.46971, 0.02628, -0.54155, -0.15518, -0.14107, -0.039722, 0.28277, 0.14393, 0.23464, -0.31021, 0.086173, 0.20397, 0.52624, 0.17164, -0.082378, -0.71787, -0.41531, 0.20335, -0.12763, 0.41367, 0.55187, 0.57908, -0.33477, -0.36559, -0.54857, -0.062892, 0.26584, 0.30205, 0.99775, -0.80481, -3.0243, 0.01254, -0.36942, 2.2167, 0.72201, -0.24978, 0.92136, 0.034514, 0.46745, 1.1079, -0.19358, -0.074575, 0.23353, -0.052062, -0.22044, 0.057162, -0.15806, -0.30798, -0.41625, 0.37972, 0.15006, -0.53212, -0.2055, -1.2526, 0.071624, 0.70565, 0.49744, -0.42063, 0.26148, -1.538, -0.30223, -0.073438, -0.28312, 0.37104, -0.25217, 0.016215, -0.017099, -0.38984, 0.87424, -0.72569, -0.51058, -0.52028, -0.1459, 0.8278, 0.27062]\n" + ] + } + ], + "source": [ + "with open(glove_path, encoding='utf-8') as gf:\r\n", + " for glove in gf:\r\n", + " word, embedding = glove.split(maxsplit=1)\r\n", + " embedding = [float(s) for s in embedding.split(' ')]\r\n", + " embeddings[word] = embedding\r\n", + "print(\"预训练词向量总数:%d\" % 
len(embeddings))\r\n", + "print(f\"单词'the'的向量是:{embeddings['the']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 给数据集的词表匹配词向量\n", + "接下来,我们提取数据集的词表,需要注意的是,词表中的词编码的先后顺序是按照词出现的频率排列的,频率越高的词编码值越小。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "词表的前5个单词:[b'the', b'and', b'a', b'of', b'to']\n", + "词表的后5个单词:[b'troubles', b'virtual', b'warriors', b'widely', '']\n" + ] + } + ], + "source": [ + "word_idx = imdb_train.word_idx\r\n", + "vocab = [w for w in word_idx.keys()]\r\n", + "print(f\"词表的前5个单词:{vocab[:5]}\")\r\n", + "print(f\"词表的后5个单词:{vocab[-5:]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "观察词表的后5个单词,我们发现,最后一个词是\"\\\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", + "的二进制编码形式,使用中注意使用b'the'.decode()来进行转换('\\'并没有进行二进制编码,注意区分)。\n", + "接下来,我们给词表中的每个词匹配对应的词向量。预训练词向量可能没有覆盖数据集词表中的所有词,对于没有的词,我们设该词的词\n", + "向量为零向量。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 定义词向量的维度,注意与预训练词向量保持一致\r\n", + "dim = 100\r\n", + "\r\n", + "vocab_embeddings = np.zeros((len(vocab), dim))\r\n", + "for ind, word in enumerate(vocab):\r\n", + " if word != '':\r\n", + " word = word.decode()\r\n", + " embedding = embeddings.get(word, np.zeros((dim,)))\r\n", + " vocab_embeddings[ind, :] = embedding" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 构建基于预训练向量的Embedding\n", + "对于预训练向量的Embedding,我们一般期望它的参数不再变动,所以要设置trainable=False。如果希望在此基础上训练参数,则需要\n", + "设置trainable=True。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "pretrained_attr = pd.ParamAttr(name='embedding',\r\n", + " initializer=pd.nn.initializer.Assign(vocab_embeddings),\r\n", + " trainable=False)\r\n", + "embedding_layer = pd.nn.Embedding(num_embeddings=len(vocab),\r\n", + " embedding_dim=dim,\r\n", + " padding_idx=word_idx[''],\r\n", + " weight_attr=pretrained_attr)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 构建分类器\n", + "这里,我们构建简单的基于一维卷积的分类模型,其结构为:Embedding->Conv1D->Pool1D->Linear。在定义Linear时,由于需要知\n", + "道输入向量的维度,我们可以按照公式[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-beta/api/paddle/nn/layer/conv/Conv2d_cn.html)\n", + "来进行计算。这里给出计算的函数如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "===========================================================================\n", + " Embedding-1 [[1, 2000]] [1, 2000, 100] 514,700 \n", + " Conv1D-1 [[1, 2000, 100]] [1, 998, 10] 5,010 \n", + " ReLU-1 [[1, 998, 10]] [1, 998, 10] 0 \n", + " MaxPool1D-1 [[1, 998, 10]] [1, 998, 5] 0 \n", + " Flatten-1 [[1, 998, 5]] [1, 4990] 0 \n", + " Linear-1 [[1, 4990]] [1, 2] 9,982 \n", + " Softmax-1 [[1, 2]] [1, 2] 0 \n", + "===========================================================================\n", + "Total params: 529,692\n", + "Trainable params: 529,692\n", + "Non-trainable params: 0\n", + "---------------------------------------------------------------------------\n", + 
"Input size (MB): 0.01\n", + "Forward/backward pass size (MB): 1.75\n", + "Params size (MB): 2.02\n", + "Estimated Total Size (MB): 3.78\n", + "---------------------------------------------------------------------------\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'total_params': 529692, 'trainable_params': 529692}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def cal_output_shape(input_shape, out_channels, kernel_size, stride, padding=0, dilation=1):\r\n", + " return out_channels, int((input_shape + 2*padding - (dilation*(kernel_size - 1) + 1)) / stride) + 1\r\n", + "\r\n", + "\r\n", + "# 定义每个样本的长度\r\n", + "length = 2000\r\n", + "\r\n", + "# 定义卷积层参数\r\n", + "kernel_size = 5\r\n", + "out_channels = 10\r\n", + "stride = 2\r\n", + "padding = 0\r\n", + "\r\n", + "output_shape = cal_output_shape(length, out_channels, kernel_size, stride, padding)\r\n", + "output_shape = cal_output_shape(output_shape[1], output_shape[0], 2, 2, 0)\r\n", + "sim_model = pd.nn.Sequential(embedding_layer,\r\n", + " pd.nn.Conv1D(in_channels=dim, out_channels=out_channels, kernel_size=kernel_size,\r\n", + " stride=stride, padding=padding, data_format='NLC', bias_attr=True),\r\n", + " pd.nn.ReLU(),\r\n", + " pd.nn.MaxPool1D(kernel_size=2, stride=2),\r\n", + " pd.nn.Flatten(),\r\n", + " pd.nn.Linear(in_features=np.prod(output_shape), out_features=2, bias_attr=True),\r\n", + " pd.nn.Softmax())\r\n", + "\r\n", + "pd.summary(sim_model, input_size=(-1, length), dtypes='int64')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 读取数据,进行训练\n", + "我们可以利用飞桨2.0的io.Dataset模块来构建一个数据的读取器,方便地将数据进行分批训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "step 586/586 [==============================] - loss: 0.3141 - acc: 0.9740 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.3515 - acc: 0.8630 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 2/10\n", + "step 586/586 [==============================] - loss: 0.3516 - acc: 0.9751 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4115 - acc: 0.8698 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 3/10\n", + "step 586/586 [==============================] - loss: 0.3136 - acc: 0.9766 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4114 - acc: 0.8747 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 4/10\n", + "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9777 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4132 - acc: 0.8659 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 5/10\n", + "step 586/586 [==============================] - loss: 0.3135 - acc: 0.9768 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4135 - acc: 0.8677 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 6/10\n", + "step 586/586 [==============================] - loss: 0.3137 - acc: 0.9777 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4185 - acc: 0.8662 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 7/10\n", + "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9790 - 6ms/step \n", + "Eval begin...\n", + 
"step 196/196 [==============================] - loss: 0.4134 - acc: 0.8722 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 8/10\n", + "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9797 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4131 - acc: 0.8672 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 9/10\n", + "step 586/586 [==============================] - loss: 0.3800 - acc: 0.9793 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4161 - acc: 0.8627 - 2ms/step \n", + "Eval samples: 6250\n", + "Epoch 10/10\n", + "step 586/586 [==============================] - loss: 0.3564 - acc: 0.9795 - 6ms/step \n", + "Eval begin...\n", + "step 196/196 [==============================] - loss: 0.4151 - acc: 0.8678 - 3ms/step \n", + "Eval samples: 6250\n" + ] + } + ], + "source": [ + "class DataReader(Dataset):\r\n", + " def __init__(self, input, label, length):\r\n", + " self.data = list(vectorizer(input, label, length=length))\r\n", + "\r\n", + " def __getitem__(self, idx):\r\n", + " return self.data[idx]\r\n", + "\r\n", + " def __len__(self):\r\n", + " return len(self.data)\r\n", + " \r\n", + "\r\n", + "# 定义输入格式\r\n", + "input_form = pd.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", + "label_form = pd.static.InputSpec(shape=[None, 1], dtype='int64', name='label')\r\n", + "\r\n", + "model = pd.Model(sim_model, input_form, label_form)\r\n", + "model.prepare(optimizer=pd.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),\r\n", + " loss=pd.nn.loss.CrossEntropyLoss(),\r\n", + " metrics=pd.metric.Accuracy())\r\n", + "\r\n", + "# 分割训练集和验证集\r\n", + "eval_length = int(len(train_x) * 1/4)\r\n", + "model.fit(train_data=DataReader(train_x[:-eval_length], train_y[:-eval_length], length),\r\n", + " eval_data=DataReader(train_x[-eval_length:], train_y[-eval_length:], length),\r\n", + " batch_size=32, epochs=10, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 评估效果并用模型预测" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval begin...\n", + "step 782/782 [==============================] - loss: 0.3144 - acc: 0.8558 - 2ms/step \n", + "Eval samples: 25000\n", + "Predict begin...\n", + "step 10/10 [==============================] - 2ms/step \n", + "Predict samples: 10\n", + "原文本:albert and tom are brilliant as sir and his of course the play is brilliant to begin with and nothing can compare with the and of theatre and i think you listen better in theatre but on the screen we become more intimate were more than we are in the theatre we witness subtle changes in expression we see better as well as listen both the play and the movie are moving intelligent the story of the company of historical context of the two main characters and of the parallel characters in itself if you cannot get to see it in a theatre i dont imagine its produced much these days then please do yourself a favor and get the video\n", + "预测的标签是:0, 实际标签是:0\n", + "原文本:this film has its and may some folks who frankly need a good the head but the film is top notch in every way engaging poignant relevant naturally is larger than life makes an ideal i thought the performances to be terribly strong in both leads and character provides plenty of dark humor the period is well captured the supporting 
cast well chosen this is to be seen and like a fine i only wish it were out on dvd\n", + "预测的标签是:0, 实际标签是:0\n", + "原文本:this is a movie that deserves another you havent seen it for a while or a first you were too young when it came out 1983 based on a play by the same name it is the story of an older actor who heads a company in england during world war ii it deals with his stress of trying to perform a shakespeare each night while facing problems such as theaters and a company made up of older or physically young able ones being taken for military service it also deals with his relationship with various members of his company especially with his so far it all sounds rather dull but nothing could be further from the truth while tragic overall the story is told with a lot of humor and emotions run high throughout the two male leads both received oscar for best actor and so i strongly recommend this movie to anyone who enjoys human drama shakespeare or who has ever worked in any the make up another of the movie that will be fascinating to most viewers\n", + "预测的标签是:0, 实际标签是:0\n", + "原文本:sir has played over tonight he cant remember his opening at the eyes reflect the kings madness his him the is an air of desperation about both these great actor knowing his powers are major wife aware of his into madness and knowing he is to do more than ease his passing the is really a love story between the the years they have become on one another to the extent that neither can a future without the other set during the second world concerns the of a frankly second rate company an equal number of has and led by sir a theatrical knight of what might be called the old part he is playing he stage and out over the his audience into inside most of the time deep beneath the he still remains an occasional of his earlier is to catch a glimpse of this that his audiences hope for mr very cleverly on the to the point of when you are ready to his performance as mere and he will produce a moment of subtlety and that makes you realise that a great actor is playing a great actor the same goes for mr easy to write off his br of norman as an exercise in we have a middle aged rather than camp theatrical his way through the company of the girls and loving the wicked in the were and i strongly suspect still are many men just like norman in the kind and more about the plays than many of the run with wisdom and believe the vast majority of them would with laughter at mr portrait i saw the on the london stage the norman was rather more than in the was played by the great mr jones to huge from the was a memorable performance that mr him rather to an also ran as opposed to an actor on level idea that sir and norman might be almost without each other went right out of the window norman was reduced to being his im not sure was what made for breathtaking theatre and the balance in the to the relationship both men have come a long way since their early appearances in the british new wave pictures when they became the of the vaguely class and ashamed of it the british cinema virtually committed in the 1970s they on the theatre apart from a few roles to keep the wolf from the the of more in the bright br the their with energy and talent to the world at large were still not a big movie but is a great one\n", + "预测的标签是:0, 实际标签是:0\n", + "原文本:anyone who fine acting and dialogue will br this film taken from taking sides its a funnybr br and ultimately of a relationship between br very types albert is as the br actor who barely the world war br 
around him so intent is he on the of his br company and his own psychological and emotional br tom is as norman the of the br whose apparent turns out to be anything but br really a must see\n", + "预测的标签是:0, 实际标签是:0\n", + "原文本:well i guess i know the answer to that question for the money we have been so with cat in the hat advertising and that we almost believe there has to be something good about this movie i admit i thought the trailers looked bad but i still had to give it a chance well i should have went with my it was a complete piece hollywood trash once again that the average person can be into believing anything they say is good must be good aside from the insulting fact that the film is only about 80 minutes long it obviously started with a eaten script its full of failed attempts at senseless humor and awful it jumps all over the universe with no nor direction this is then with yes ill say it bad acting i couldnt help but feel like i was watching coffee talk on every time mike myers opened his mouth was the cat intended to be a middle aged jewish woman and were no prize either but mr myers should disappear under a rock somewhere until hes ready to make another austin powers movie f no stars 0 on a scale of 110 save your money\n", + "预测的标签是:1, 实际标签是:1\n", + "原文本:when my own child is me to leave the opening show of this film i know it is bad i wanted to my eyes out i wanted to reach through the screen and slap mike myers for the last of dignity he had this is one of the few films in my life i have watched and immediately wished to if only it were possible the other films being 2 and fast and both which are better than this crap in the br i may drink myself to sleep tonight in a attempt to forget i ever witnessed this on the good br to mike myers i say stick with austin or even world just because it worked for jim carrey doesnt mean is a success for all br\n", + "预测的标签是:1, 实际标签是:1\n", + "原文本:holy what a piece of this movie is i didnt how these filmmakers could take a word book and turn it into a movie i guess they didnt know either i dont remember any or in the book do youbr br they took this all times childrens classic added some and sexual and it into a joke this should give you a good idea of what these hollywood producers think like i have to say visually it was interesting but the brilliant visual story is ruined by toilet humor if you even think that kind of thing is funny i dont want the kids that i know to think it isbr br dont take your kids to see dont rent the dvd i hope the ghost of doctor ghost comes and the people that made this movie\n", + "预测的标签是:1, 实际标签是:1\n", + "原文本:i was so looking forward to seeing this when it was in it turned out to be the the biggest let down a far cry from the world of dr it was and i dont think dr would have the stole christmas was much better i understand it had some subtle adult jokes in it but my children have yet to catch on whereas the cat in the hat they caught a lot more than i would have up with dr it really bothered me to see how this timeless classic got on the big screen lets see what they do with a hope this one does dr some justice\n", + "预测的标签是:1, 实际标签是:1\n", + "原文本:ive seen some bad things in my time a half dead trying to get out of high a head on between two cars a thousand on a kitchen floor human beings living like br but never in my life have i seen anything as bad as the cat in the br this film is worse than 911 worse than hitler worse than the worse than people who put in br it is the most disturbing film of all time br 
i used to think it was a joke some elaborate joke and that mike myers was maybe a high drug who lost a bet or br i\n", + "预测的标签是:1, 实际标签是:1\n" + ] + } + ], + "source": [ + "# 评估\r\n", + "model.evaluate(eval_data=DataReader(test_x, test_y, length), batch_size=32, verbose=1)\r\n", + "\r\n", + "# 预测\r\n", + "true_y = test_y[100:105] + test_y[-110:-105]\r\n", + "pred_y = model.predict(DataReader(test_x[100:105] + test_x[-110:-105], None, length), batch_size=1)\r\n", + "test_x_doc = test_x[100:105] + test_x[-110:-105]\r\n", + "\r\n", + "for index, y in enumerate(pred_y[0]):\r\n", + " print(\"原文本:%s\" % ' '.join([vocab[i].decode() for i in test_x_doc[index] if i < len(vocab) - 1]))\r\n", + " print(\"预测的标签是:%d, 实际标签是:%d\" % (np.argmax(y), true_y[index]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} From 53f9677fbfac8354ccab69a6051c1e0954819d0a Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Mon, 28 Dec 2020 11:20:08 +0800 Subject: [PATCH 05/14] Update pretrained_word_embeddings.ipynb MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改了两处说明(Line9,190),修改了最近修改时间(Line15) --- .../pretrained_word_embeddings.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb index 71acb0c8..6c949511 100644 --- a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb +++ b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb @@ -6,13 +6,13 @@ "collapsed": false }, "source": [ - "# 使用预训练的词向量\n", + "# 使用预训练的词向量完成文本分类任务\n", "\n", "Author: [Dongyang Yan](623320480@qq.com, github.com/fiyen )\n", "\n", "Data created: 2020/11/23\n", "\n", - "Last modified: 2020/12/12\n", + "Last modified: 2020/12/28\n", "\n", "Description: Tutorial to classify Imdb data using pre-trained word embeddings in paddlepaddle 2.0" ] @@ -187,7 +187,7 @@ }, "outputs": [], "source": [ - "# 下载预训练向量文件,此链接下载较慢,较快下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", + "# 下载预训练向量文件,此链接下载较慢,推荐从AI Studio的公开数据集进行下载(网址:https://aistudio.baidu.com/aistudio/datasetoverview),此文件的下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", "#!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", "#!unzip -q glove.6B.zip\r\n", "\r\n", From d92b89c68bc907ae57076f13d771616a11172a1e Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Thu, 4 Mar 2021 16:26:53 +0800 Subject: [PATCH 06/14] modified based on the latest comments --- .../pretrained_word_embeddings.ipynb | 190 +++++++++--------- 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb index 6c949511..a1004324 100644 --- a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb +++ b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb @@ -8,13 +8,11 @@ 
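正文提到,对于无法一次载入内存的大型预训练向量文件,可以采用分批载入、并行处理的方式进行匹配。一个常见做法是逐行流式读取,只保留数据集词表中出现的词。下面是一个假设性示意(load_embeddings_for_vocab 为本文自拟的函数名,并非飞桨 API;沿用上文的 glove_path 与 vocab,并假设词表元素为字节串或字符串):

# 假设性示意:流式读取词向量文件,按词表过滤,避免整个文件驻留内存
def load_embeddings_for_vocab(path, vocab):
    wanted = {w.decode() if isinstance(w, bytes) else w for w in vocab}
    table = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            word, vec = line.split(maxsplit=1)
            if word in wanted:
                table[word] = np.array(vec.split(), dtype='float32')
    return table

embeddings = load_embeddings_for_vocab(glove_path, vocab)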
"source": [ "# 使用预训练的词向量完成文本分类任务\n", "\n", - "Author: [Dongyang Yan](623320480@qq.com, github.com/fiyen )\n", + "**作者**: [fiyen](https://github.com/fiyen)\n", "\n", - "Data created: 2020/11/23\n", + "**日期**: 2021.03\n", "\n", - "Last modified: 2020/12/28\n", - "\n", - "Description: Tutorial to classify Imdb data using pre-trained word embeddings in paddlepaddle 2.0" + "**摘要**: 本示例教程将会演示如何使用飞桨内置的Imdb数据集,并使用预训练词向量进行文本分类。" ] }, { @@ -25,7 +23,7 @@ "source": [ "## 摘要\n", "\n", - "在这个示例中,我们将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imbd将直接调用自飞桨2.0,同时,\n", + "在这个示例中,我们将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imdb将直接调用自飞桨2.0,同时,\n", "利用预训练的词向量([GloVe embedding](http://nlp.stanford.edu/projects/glove/))完成任务。" ] }, @@ -40,16 +38,16 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "import paddle as pd\r\n", + "import paddle\r\n", "from paddle.io import Dataset\r\n", "import numpy as np\r\n", - "import paddle.text as pt\r\n", + "import paddle.text as text\r\n", "import random" ] }, @@ -69,14 +67,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "imdb_train = pt.Imdb(mode='train', cutoff=150)\r\n", - "imdb_test = pt.Imdb(mode='test', cutoff=150)" + "imdb_train = text.Imdb(mode='train', cutoff=150)\r\n", + "imdb_test = text.Imdb(mode='test', cutoff=150)" ] }, { @@ -90,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "collapsed": false }, @@ -126,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "collapsed": false }, @@ -153,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, @@ -175,23 +173,22 @@ }, "source": [ "## 载入预训练向量。\n", - "以下给出的文件较小,可以直接完全载入内存。对于大型的预训练向量,无法一次载入内存的,可以采用分批载入,并行\n", - "处理的方式进行匹配。这里略过此部分,如果感兴趣可以参考[此链接](https://aistudio.baidu.com/aistudio/projectdetail/496368)进一步了解。" + "以下给出的文件较小,可以直接完全载入内存。对于大型的预训练向量,无法一次载入内存的,可以采用分批载入,并行处理的方式进行匹配。" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "# 下载预训练向量文件,此链接下载较慢,推荐从AI Studio的公开数据集进行下载(网址:https://aistudio.baidu.com/aistudio/datasetoverview),此文件的下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", - "#!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", - "#!unzip -q glove.6B.zip\r\n", + "# 下载预训练向量文件,此链接下载较慢,推荐从AI Studio的公开数据集进行下载,此文件的下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", + "!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", + "!unzip -q glove.6B.zip\r\n", "\r\n", - "glove_path = \"./data/data42051/glove.6B.100d.txt\" # 请修改至glove.6B.100d.txt所在位置\r\n", + "glove_path = \"./glove.6B.100d.txt\" # 请修改至glove.6B.100d.txt所在位置\r\n", "embeddings = {}" ] }, @@ -206,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false }, @@ -238,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false }, @@ -345,13 +342,13 @@ }, "outputs": [], "source": [ - "pretrained_attr = pd.ParamAttr(name='embedding',\r\n", - " initializer=pd.nn.initializer.Assign(vocab_embeddings),\r\n", - " trainable=False)\r\n", - "embedding_layer = pd.nn.Embedding(num_embeddings=len(vocab),\r\n", - " embedding_dim=dim,\r\n", - " 
padding_idx=word_idx[''],\r\n", - " weight_attr=pretrained_attr)" + "pretrained_attr = paddle.ParamAttr(name='embedding',\r\n", + " initializer=paddle.nn.initializer.Assign(vocab_embeddings),\r\n", + " trainable=False)\r\n", + "embedding_layer = paddle.nn.Embedding(num_embeddings=len(vocab),\r\n", + " embedding_dim=dim,\r\n", + " padding_idx=word_idx[''],\r\n", + " weight_attr=pretrained_attr)" ] }, { @@ -389,8 +386,8 @@ " Softmax-1 [[1, 2]] [1, 2] 0 \n", "===========================================================================\n", "Total params: 529,692\n", - "Trainable params: 529,692\n", - "Non-trainable params: 0\n", + "Trainable params: 14,992\n", + "Non-trainable params: 514,700\n", "---------------------------------------------------------------------------\n", "Input size (MB): 0.01\n", "Forward/backward pass size (MB): 1.75\n", @@ -403,7 +400,7 @@ { "data": { "text/plain": [ - "{'total_params': 529692, 'trainable_params': 529692}" + "{'total_params': 529692, 'trainable_params': 14992}" ] }, "execution_count": 12, @@ -427,16 +424,16 @@ "\r\n", "output_shape = cal_output_shape(length, out_channels, kernel_size, stride, padding)\r\n", "output_shape = cal_output_shape(output_shape[1], output_shape[0], 2, 2, 0)\r\n", - "sim_model = pd.nn.Sequential(embedding_layer,\r\n", - " pd.nn.Conv1D(in_channels=dim, out_channels=out_channels, kernel_size=kernel_size,\r\n", - " stride=stride, padding=padding, data_format='NLC', bias_attr=True),\r\n", - " pd.nn.ReLU(),\r\n", - " pd.nn.MaxPool1D(kernel_size=2, stride=2),\r\n", - " pd.nn.Flatten(),\r\n", - " pd.nn.Linear(in_features=np.prod(output_shape), out_features=2, bias_attr=True),\r\n", - " pd.nn.Softmax())\r\n", + "sim_model = paddle.nn.Sequential(embedding_layer,\r\n", + " paddle.nn.Conv1D(in_channels=dim, out_channels=out_channels, kernel_size=kernel_size,\r\n", + " stride=stride, padding=padding, data_format='NLC', bias_attr=True),\r\n", + " paddle.nn.ReLU(),\r\n", + " paddle.nn.MaxPool1D(kernel_size=2, stride=2),\r\n", + " paddle.nn.Flatten(),\r\n", + " paddle.nn.Linear(in_features=np.prod(output_shape), out_features=2, bias_attr=True),\r\n", + " paddle.nn.Softmax())\r\n", "\r\n", - "pd.summary(sim_model, input_size=(-1, length), dtypes='int64')" + "paddle.summary(sim_model, input_size=(-1, length), dtypes='int64')" ] }, { @@ -451,7 +448,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": { "collapsed": false }, @@ -460,55 +457,66 @@ "name": "stdout", "output_type": "stream", "text": [ + "The loss value printed in the log is the current step, and the metric is the average value of previous step.\n", "Epoch 1/10\n", - "step 586/586 [==============================] - loss: 0.3141 - acc: 0.9740 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.7259 - acc: 0.7708 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.3515 - acc: 0.8630 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.5196 - acc: 0.7006 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 2/10\n", - "step 586/586 [==============================] - loss: 0.3516 - acc: 0.9751 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.5012 - acc: 0.8090 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4115 - acc: 0.8698 - 2ms/step \n", + "The loss value printed in the log is the 
current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.5776 - acc: 0.7886 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 3/10\n", - "step 586/586 [==============================] - loss: 0.3136 - acc: 0.9766 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.5459 - acc: 0.8248 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4114 - acc: 0.8747 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4988 - acc: 0.8182 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 4/10\n", - "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9777 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.4274 - acc: 0.8431 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4132 - acc: 0.8659 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4896 - acc: 0.8051 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 5/10\n", - "step 586/586 [==============================] - loss: 0.3135 - acc: 0.9768 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.4212 - acc: 0.8501 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4135 - acc: 0.8677 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.5174 - acc: 0.8144 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 6/10\n", - "step 586/586 [==============================] - loss: 0.3137 - acc: 0.9777 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.4084 - acc: 0.8605 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4185 - acc: 0.8662 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4635 - acc: 0.8266 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 7/10\n", - "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9790 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.4713 - acc: 0.8697 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4134 - acc: 0.8722 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.5352 - acc: 0.8222 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 8/10\n", - "step 586/586 [==============================] - loss: 0.3133 - acc: 0.9797 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.5050 - acc: 0.8745 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4131 - acc: 0.8672 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4925 - acc: 0.8248 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 9/10\n", - "step 586/586 
[==============================] - loss: 0.3800 - acc: 0.9793 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.5348 - acc: 0.8832 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4161 - acc: 0.8627 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4936 - acc: 0.8078 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 10/10\n", - "step 586/586 [==============================] - loss: 0.3564 - acc: 0.9795 - 6ms/step \n", + "step 586/586 [==============================] - loss: 0.5156 - acc: 0.8846 - 4ms/step \n", "Eval begin...\n", - "step 196/196 [==============================] - loss: 0.4151 - acc: 0.8678 - 3ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 196/196 [==============================] - loss: 0.4882 - acc: 0.8286 - 3ms/step \n", "Eval samples: 6250\n" ] } @@ -526,13 +534,13 @@ " \r\n", "\r\n", "# 定义输入格式\r\n", - "input_form = pd.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", - "label_form = pd.static.InputSpec(shape=[None, 1], dtype='int64', name='label')\r\n", + "input_form = paddle.static.InputSpec(shape=[None, length], dtype='int64', name='input')\r\n", + "label_form = paddle.static.InputSpec(shape=[None, 1], dtype='int64', name='label')\r\n", "\r\n", - "model = pd.Model(sim_model, input_form, label_form)\r\n", - "model.prepare(optimizer=pd.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),\r\n", - " loss=pd.nn.loss.CrossEntropyLoss(),\r\n", - " metrics=pd.metric.Accuracy())\r\n", + "model = paddle.Model(sim_model, input_form, label_form)\r\n", + "model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),\r\n", + " loss=paddle.nn.loss.CrossEntropyLoss(),\r\n", + " metrics=paddle.metric.Accuracy())\r\n", "\r\n", "# 分割训练集和验证集\r\n", "eval_length = int(len(train_x) * 1/4)\r\n", @@ -552,7 +560,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": { "collapsed": false }, @@ -562,31 +570,32 @@ "output_type": "stream", "text": [ "Eval begin...\n", - "step 782/782 [==============================] - loss: 0.3144 - acc: 0.8558 - 2ms/step \n", + "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", + "step 782/782 [==============================] - loss: 0.4061 - acc: 0.8207 - 3ms/step \n", "Eval samples: 25000\n", "Predict begin...\n", "step 10/10 [==============================] - 2ms/step \n", "Predict samples: 10\n", "原文本:albert and tom are brilliant as sir and his of course the play is brilliant to begin with and nothing can compare with the and of theatre and i think you listen better in theatre but on the screen we become more intimate were more than we are in the theatre we witness subtle changes in expression we see better as well as listen both the play and the movie are moving intelligent the story of the company of historical context of the two main characters and of the parallel characters in itself if you cannot get to see it in a theatre i dont imagine its produced much these days then please do yourself a favor and get the video\n", - "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:positive, 实际标签是:positive\n", "原文本:this film has its and may some folks who frankly need a good the head but the film is top notch in 
every way engaging poignant relevant naturally is larger than life makes an ideal i thought the performances to be terribly strong in both leads and character provides plenty of dark humor the period is well captured the supporting cast well chosen this is to be seen and like a fine i only wish it were out on dvd\n", - "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:positive, 实际标签是:positive\n", "原文本:this is a movie that deserves another you havent seen it for a while or a first you were too young when it came out 1983 based on a play by the same name it is the story of an older actor who heads a company in england during world war ii it deals with his stress of trying to perform a shakespeare each night while facing problems such as theaters and a company made up of older or physically young able ones being taken for military service it also deals with his relationship with various members of his company especially with his so far it all sounds rather dull but nothing could be further from the truth while tragic overall the story is told with a lot of humor and emotions run high throughout the two male leads both received oscar for best actor and so i strongly recommend this movie to anyone who enjoys human drama shakespeare or who has ever worked in any the make up another of the movie that will be fascinating to most viewers\n", - "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:positive, 实际标签是:positive\n", "原文本:sir has played over tonight he cant remember his opening at the eyes reflect the kings madness his him the is an air of desperation about both these great actor knowing his powers are major wife aware of his into madness and knowing he is to do more than ease his passing the is really a love story between the the years they have become on one another to the extent that neither can a future without the other set during the second world concerns the of a frankly second rate company an equal number of has and led by sir a theatrical knight of what might be called the old part he is playing he stage and out over the his audience into inside most of the time deep beneath the he still remains an occasional of his earlier is to catch a glimpse of this that his audiences hope for mr very cleverly on the to the point of when you are ready to his performance as mere and he will produce a moment of subtlety and that makes you realise that a great actor is playing a great actor the same goes for mr easy to write off his br of norman as an exercise in we have a middle aged rather than camp theatrical his way through the company of the girls and loving the wicked in the were and i strongly suspect still are many men just like norman in the kind and more about the plays than many of the run with wisdom and believe the vast majority of them would with laughter at mr portrait i saw the on the london stage the norman was rather more than in the was played by the great mr jones to huge from the was a memorable performance that mr him rather to an also ran as opposed to an actor on level idea that sir and norman might be almost without each other went right out of the window norman was reduced to being his im not sure was what made for breathtaking theatre and the balance in the to the relationship both men have come a long way since their early appearances in the british new wave pictures when they became the of the vaguely class and ashamed of it the british cinema virtually committed in the 1970s they on the theatre apart from a few roles to keep the wolf from the the of more in the bright br the their with energy and talent 
to the world at large were still not a big movie but is a great one\n", - "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:positive, 实际标签是:positive\n", "原文本:anyone who fine acting and dialogue will br this film taken from taking sides its a funnybr br and ultimately of a relationship between br very types albert is as the br actor who barely the world war br around him so intent is he on the of his br company and his own psychological and emotional br tom is as norman the of the br whose apparent turns out to be anything but br really a must see\n", - "预测的标签是:0, 实际标签是:0\n", + "预测的标签是:positive, 实际标签是:positive\n", "原文本:well i guess i know the answer to that question for the money we have been so with cat in the hat advertising and that we almost believe there has to be something good about this movie i admit i thought the trailers looked bad but i still had to give it a chance well i should have went with my it was a complete piece hollywood trash once again that the average person can be into believing anything they say is good must be good aside from the insulting fact that the film is only about 80 minutes long it obviously started with a eaten script its full of failed attempts at senseless humor and awful it jumps all over the universe with no nor direction this is then with yes ill say it bad acting i couldnt help but feel like i was watching coffee talk on every time mike myers opened his mouth was the cat intended to be a middle aged jewish woman and were no prize either but mr myers should disappear under a rock somewhere until hes ready to make another austin powers movie f no stars 0 on a scale of 110 save your money\n", - "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:negative, 实际标签是:negative\n", "原文本:when my own child is me to leave the opening show of this film i know it is bad i wanted to my eyes out i wanted to reach through the screen and slap mike myers for the last of dignity he had this is one of the few films in my life i have watched and immediately wished to if only it were possible the other films being 2 and fast and both which are better than this crap in the br i may drink myself to sleep tonight in a attempt to forget i ever witnessed this on the good br to mike myers i say stick with austin or even world just because it worked for jim carrey doesnt mean is a success for all br\n", - "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:negative, 实际标签是:negative\n", "原文本:holy what a piece of this movie is i didnt how these filmmakers could take a word book and turn it into a movie i guess they didnt know either i dont remember any or in the book do youbr br they took this all times childrens classic added some and sexual and it into a joke this should give you a good idea of what these hollywood producers think like i have to say visually it was interesting but the brilliant visual story is ruined by toilet humor if you even think that kind of thing is funny i dont want the kids that i know to think it isbr br dont take your kids to see dont rent the dvd i hope the ghost of doctor ghost comes and the people that made this movie\n", - "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:negative, 实际标签是:negative\n", "原文本:i was so looking forward to seeing this when it was in it turned out to be the the biggest let down a far cry from the world of dr it was and i dont think dr would have the stole christmas was much better i understand it had some subtle adult jokes in it but my children have yet to catch on whereas the cat in the hat they caught a lot more than i would have up with dr it really bothered me to see how this timeless 
classic got on the big screen lets see what they do with a hope this one does dr some justice\n", - "预测的标签是:1, 实际标签是:1\n", + "预测的标签是:negative, 实际标签是:negative\n", "原文本:ive seen some bad things in my time a half dead trying to get out of high a head on between two cars a thousand on a kitchen floor human beings living like br but never in my life have i seen anything as bad as the cat in the br this film is worse than 911 worse than hitler worse than the worse than people who put in br it is the most disturbing film of all time br i used to think it was a joke some elaborate joke and that mike myers was maybe a high drug who lost a bet or br i\n", - "预测的标签是:1, 实际标签是:1\n" + "预测的标签是:negative, 实际标签是:negative\n" ] } ], @@ -599,9 +608,12 @@ "pred_y = model.predict(DataReader(test_x[100:105] + test_x[-110:-105], None, length), batch_size=1)\r\n", "test_x_doc = test_x[100:105] + test_x[-110:-105]\r\n", "\r\n", + "# 标签编码转文字\r\n", + "label_id2text = {0: 'positive', 1: 'negative'}\r\n", + "\r\n", "for index, y in enumerate(pred_y[0]):\r\n", " print(\"原文本:%s\" % ' '.join([vocab[i].decode() for i in test_x_doc[index] if i < len(vocab) - 1]))\r\n", - " print(\"预测的标签是:%d, 实际标签是:%d\" % (np.argmax(y), true_y[index]))" + " print(\"预测的标签是:%s, 实际标签是:%s\" % (label_id2text[np.argmax(y)], label_id2text[true_y[index]]))" ] } ], @@ -610,18 +622,6 @@ "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", "language": "python", "name": "py35-paddle1.2.0" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" } }, "nbformat": 4, From 88a02f449513520580676dffd41bf938ede6029c Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 17 Apr 2021 13:05:42 +0800 Subject: [PATCH 07/14] Create README.md --- paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md b/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md new file mode 100644 index 00000000..e9dff202 --- /dev/null +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/README.md @@ -0,0 +1 @@ +基于预训练BERT的古诗生成器 From 02a77442b721e4921681677628cb2a4ac85bbc65 Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 17 Apr 2021 13:07:01 +0800 Subject: [PATCH 08/14] Add files via upload --- ...retrained_bert_for_poetry_generation.ipynb | 934 ++++++++++++++++++ 1 file changed, 934 insertions(+) create mode 100644 paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb new file mode 100644 index 00000000..ec6794f2 --- /dev/null +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb @@ -0,0 +1,934 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 用BERT实现自动写诗\n", + "\n", + "**作者**:[fiyen](https://github.com/fiyen)\n", + "\n", + "**日期**:2021.04\n", + "\n", + "**摘要**:本示例教程将会演示如何使用飞桨2.0以及PaddleNLP快速实现用BERT预训练模型生成高质量诗歌。" + ] + }, + { + "cell_type": 
"markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 摘要\n", + "在这个示例中,我们将快速构建基于BERT预训练模型的古诗生成器,支持诗歌风格定制,以及生成藏头诗。模型基于飞桨2.0框架,BERT预训练模型则调用自PaddleNLP,诗歌数据集采用Github开源数据集。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 相关内容介绍\n", + "\n", + "### PaddleNLP\n", + "\n", + "官网链接:[https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP](https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP)\n", + "\n", + "![](https://github.com/fiyen/models/raw/release/2.0-beta/PaddleNLP/docs/imgs/paddlenlp.png)\n", + "\n", + "\n", + "PaddleNLP旨在帮助开发者提高文本建模的效率,通过丰富的模型库、简洁易用的API,提供飞桨2.0的最佳实践并加速NLP领域应用产业落地效率。其产品特性如下:\n", + "\n", + "* **丰富的模型库**\n", + "\n", + "涵盖了NLP主流应用相关的前沿模型,包括中文词向量、预训练模型、词法分析、文本分类、文本匹配、文本生成、机器翻译、通用对话、问答系统等。\n", + "\n", + "* **简洁易用的API**\n", + "\n", + "深度兼容飞桨2.0的高层API体系,提供更多可复用的文本建模模块,可大幅度减少数据处理、组网、训练环节的代码开发,提高开发效率。\n", + "\n", + "* **高性能分布式训练**\n", + "\n", + "通过高度优化的Transformer网络实现,结合混合精度与Fleet分布式训练API,可充分利用GPU集群资源,高效完成预训练模型的分布式训练。\n", + "\n", + "### BERT\n", + "\n", + "BERT的全称为Bidirectional Encoder Representations from Transformers,即基于Transformers的双向编码表示模型。BERT是Transformers应用的一次巨大的成功。在该模型提出时,其在NLP领域的11个方向上都大幅刷新了SOTA。其模型的主要特点可以归纳如下:\n", + "\n", + "1. 基于Transformer。Transformer的提出将注意力机制的应用发挥到了极致,同时也解决了基于RNN的注意力机制的无法并行计算的问题,使超大规模的模型训练在时间上变得可以接受;\n", + "\n", + "2. 双向编码。其实双向编码不是BERT首创,但是基于Transformer与双向编码结合使这一做法的效用得到了最充分的发挥;\n", + "\n", + "3. 使用MLM(Mask Language Model)和NSP(Next Sentence Prediction)实现多任务训练的目标。\n", + "\n", + "4. 迁移学习。BERT模型展现出了大规模数据训练带来的有效性,而更重要的一点是,BERT实质上是一种更好的语义表征,相较于经典的Word2Vec,Glove等模型具有更好词嵌入特征。在实际应用中,我们可以直接调用训练好的BERT模型作为特征表示,进而设计下游任务。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 数据设置\n", + "在这一部分,我们对数据进行预处理,并构建训练用的数据读取器。\n", + "\n", + "## 数据准备\n", + "诗歌数据集采用Github上开源的[中华古诗词数据库](https://github.com/chinese-poetry/chinese-poetry)。在此,我们只使用其中的唐诗和宋诗的数据即可(json文件夹下)。" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 下载诗歌数据集 (从镜像网站github.com.cnpmjs.org下载可提高下载速度)\r\n", + "!git clone https://github.com.cnpmjs.org/chinese-poetry/chinese-poetry" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "此数据集中多数诗歌内容为繁体字,为了适应基于简体中文的预训练模型,我们对数据进行预处理,将繁体字转换为简体字。首先调用Github上开源的繁转简工具。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 下载繁体转简体工具\r\n", + "!git clone https://github.com.cnpmjs.org/fiyen/cht2chs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 数据处理\n", + "剔除数据集中的特殊符号,并将繁体转简体。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import os\r\n", + "import json\r\n", + "import re\r\n", + "from cht2chs.langconv import cht_to_chs\r\n", + "\r\n", + "def sentenceParse(para):\r\n", + " \"\"\"\r\n", + " 剔除诗歌字符中的非文字符号以及数字\r\n", + " \"\"\"\r\n", + " result, number = re.subn(u\"(.*)\", \"\", para)\r\n", + " result, number = re.subn(u\"{.*}\", \"\", result)\r\n", + " result, number = re.subn(u\"《.*》\", \"\", result)\r\n", + " result, number = re.subn(u\"《.*》\", \"\", result)\r\n", + " result, number = re.subn(u\"[\\]\\[]\", \"\", result)\r\n", + " r = \"\"\r\n", + " for s in result:\r\n", + " if s not in set('0123456789-'):\r\n", + " r += s\r\n", + " r, number = re.subn(u\"。。\", u\"。\", r)\r\n", + " 
return r\r\n", + "\r\n", + "\r\n", + "def data_preprocess(poem_dir='./chinese-poetry/json', len_limit=120):\r\n", + " \"\"\"\r\n", + " 预处理诗歌数据,返回符合要求的诗歌列表\r\n", + " \"\"\"\r\n", + " poems = []\r\n", + " for f in os.listdir(poem_dir):\r\n", + " if f.endswith('.json'):\r\n", + " json_data = json.load(open(os.path.join(poem_dir, f)))\r\n", + " for d in json_data:\r\n", + " try:\r\n", + " poem = ''.join(d['paragraphs'])\r\n", + " poem = sentenceParse(poem)\r\n", + " # 控制长度,并将繁体字转换为简体字\r\n", + " if len(poem) <= len_limit:\r\n", + " poems.append(cht_to_chs(poem))\r\n", + " except:\r\n", + " continue\r\n", + " return poems" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 开始处理\r\n", + "poems = data_preprocess()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "从PaddleNLP调用基于BERT预训练模型的分词工具,对诗歌进行分词和编码。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddlenlp.transformers import BertTokenizer\r\n", + "\r\n", + "bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“\\[CLS\\]”标记,在结尾添加“\\[SEP\\]”标记,这些标记在BERT模型训练中扮演者特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“\\[UNK\\]”表示分词工具无法识别的符号,“\\[PAD\\]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "楚王台榭荆榛里,屈指江山俎豆中。\n", + "[101, 3504, 4374, 1378, 3531, 5769, 3527, 7027, 8024, 2235, 2900, 3736, 2255, 917, 6486, 704, 511, 102]\n", + "[CLS]楚王台榭荆榛里,屈指江山俎豆中。[SEP]\n", + "百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。\n", + "[101, 4636, 2399, 2129, 4373, 4767, 8024, 676, 7027, 5811, 2687, 740, 511, 1765, 2970, 5769, 7305, 6818, 8024, 4170, 6837, 3727, 3717, 7270, 511, 102]\n", + "[CLS]百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。[SEP]\n" + ] + } + ], + "source": [ + "# 处理效果展示\r\n", + "for poem in poems[6:8]:\r\n", + " token_poem, _ = bert_tokenizer.encode(poem).values()\r\n", + " print(poem)\r\n", + " print(token_poem)\r\n", + " print(''.join(bert_tokenizer.convert_ids_to_tokens(token_poem)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 构造数据读取器\n", + "预处理数据后,我们基于飞桨2.0构造数据读取器,以适应后续模型的训练。\n", + "\n", + "需注意以下类定义中包含填充内容,使输入样本对齐到一个特定的长度,以便于模型进行批处理运算。因此在得到数据读取器的实例时,需注意参数max_len,其不超过模型所支持的最大长度(PaddleNLP默认的序列最长长度为512)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import paddle\r\n", + "from paddle.io import Dataset\r\n", + "import numpy as np\r\n", + "\r\n", + "class PoemData(Dataset):\r\n", + " \"\"\"\r\n", + " 构造诗歌数据集,继承paddle.io.Dataset\r\n", + " Parameters:\r\n", + " poems (list): 诗歌数据列表,每一个元素为一首诗歌,诗歌未经编码\r\n", + " max_len: 接收诗歌的最大长度\r\n", + " \"\"\"\r\n", + " def __init__(self, poems, tokenizer, max_len=128):\r\n", + " super(PoemData, self).__init__()\r\n", + " self.poems = poems\r\n", + " self.tokenizer = tokenizer\r\n", + " self.max_len = max_len\r\n", + " \r\n", + " def __getitem__(self, idx):\r\n", + " line = poems[idx]\r\n", + " token_line = self.tokenizer.encode(line)\r\n", + " token, token_type = token_line['input_ids'], token_line['token_type_ids']\r\n", + " if len(token) > self.max_len + 1:\r\n", + " token = 
token[:self.max_len] + token[-1]\r\n", + " token_type = token_type[:self.max_len] + token_type[-1]\r\n", + " input_token, input_token_type = token[:-1], token_type[:-1]\r\n", + " label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " # 输入填充\r\n", + " input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", + " input_pad_mask = (input_token != 0).astype('float32')\r\n", + " return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask\r\n", + " \r\n", + " def __len__(self):\r\n", + " return len(self.poems)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 模型设置与训练\n", + "在这一部分,我们将快速搭建基于BERT预训练模型的古诗生成器,并对模型进行训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 预训练BERT模型\n", + "古诗生成是一个文本生成的过程,在实际中模型无法获知还未生成的内容,也即BERT中的双向关系中只能捕捉到前向关系而不能捕捉到后向关系。这个限制我们可以通过添加注意力掩码(attention mask)来屏蔽掉后向的关系,使模型无法注意到还未生成的内容,从而使BERT仍能完成文本生成任务。\n", + "\n", + "进一步地,我们可以将文本生成简化为基于BERT的词分类模型(理解为词性标注),即赋予每个词一个标签,该标签即该词后的下一个词是什么。因此,我们直接调用PaddleNLP的BERT词分类模型即可看,需注意模型分类的类别为词表长度。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddlenlp.transformers import BertModel, BertForTokenClassification\r\n", + "from paddle.nn import Layer, Linear, Softmax\r\n", + "\r\n", + "class PoetryBertModel(Layer):\r\n", + " \"\"\"\r\n", + " 基于BERT预训练模型的诗歌生成模型\r\n", + " \"\"\"\r\n", + " def __init__(self, pretrained_bert_model: str, input_length: int):\r\n", + " super(PoetryBertModel, self).__init__()\r\n", + " bert_model = BertModel.from_pretrained(pretrained_bert_model)\r\n", + " self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape\r\n", + " self.bert_for_class = BertForTokenClassification(bert_model, self.vocab_size)\r\n", + " # 生成下三角矩阵,用来mask句子后边的信息\r\n", + " self.sequence_length = input_length\r\n", + " self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n", + "\r\n", + " def forward(self, token, token_type, input_mask, input_length=None):\r\n", + " # 计算attention mask\r\n", + " mask_left = paddle.reshape(input_mask, input_mask.shape + [1])\r\n", + " mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])\r\n", + " # 输入句子中有效的位置\r\n", + " mask_left = paddle.cast(mask_left, 'float32')\r\n", + " mask_right = paddle.cast(mask_right, 'float32')\r\n", + " attention_mask = paddle.matmul(mask_left, mask_right)\r\n", + " # 注意力机制计算中有效的位置\r\n", + " if input_length is not None:\r\n", + " lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n", + " else:\r\n", + " lower_triangle_mask = self.lower_triangle_mask\r\n", + " attention_mask = attention_mask * lower_triangle_mask\r\n", + " # 无效的位置设为极小值\r\n", + " attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10\r\n", + " attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)\r\n", + "\r\n", + " output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)\r\n", + " \r\n", + " return output_logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 定义模型损失\n", + 
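Editorial aside before the loss definition: the mask arithmetic inside `PoetryBertModel.forward` above is the trick that lets a bidirectional BERT generate text left to right. A toy standalone sketch (illustrative values only, not part of the patch) of how the padding mask and the lower-triangular causal mask combine into an additive attention mask:

```python
import paddle

seq_len = 4
# One sequence with three real tokens and one trailing padding position.
input_mask = paddle.to_tensor([[1., 1., 1., 0.]])

# Outer product marks (query, key) pairs where both positions are real tokens.
mask_left = paddle.reshape(input_mask, [1, seq_len, 1])
mask_right = paddle.reshape(input_mask, [1, 1, seq_len])
pad_mask = paddle.matmul(mask_left, mask_right)          # shape [1, 4, 4]

# The lower triangle hides every position that comes later in the poem.
causal_mask = paddle.tril(paddle.full((seq_len, seq_len), 1, dtype='float32'))
attention_mask = (1 - pad_mask * causal_mask) * -1e10    # additive attention mask
print(attention_mask.numpy()[0])
```

Entries equal to 0 stay visible to attention; every future or padded position becomes -1e10, which effectively vanishes after the softmax.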
"由于真实值中有相当一部分是填充内容,我们需重写交叉熵损失,使其忽略填充内容带来的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class PoetryBertModelLossCriterion(Layer):\r\n", + " def forward(self, pred_logits, label, input_mask):\r\n", + " loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')\r\n", + " masked_loss = paddle.mean(loss * input_mask, axis=0)\r\n", + " return paddle.sum(masked_loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 模型准备\n", + "针对预训练模型的训练,需使用较小的学习率(learning_rate)进行调优。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2021-04-16 19:31:55,774] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams\n", + "[2021-04-16 19:32:00,118] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------------------------------------------------------------------------------\n", + " Layer (type) Input Shape Output Shape Param # \n", + "========================================================================================================================================\n", + " Embedding-16 [[1, 128]] [1, 128, 768] 16,226,304 \n", + " Embedding-17 [[1, 128]] [1, 128, 768] 393,216 \n", + " Embedding-18 [[1, 128]] [1, 128, 768] 1,536 \n", + " LayerNorm-129 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Dropout-188 [[1, 128, 768]] [1, 128, 768] 0 \n", + " BertEmbeddings-6 [] [1, 128, 768] 0 \n", + " Linear-371 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-372 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-373 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-374 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-61 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-190 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-130 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-375 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-189 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-376 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-191 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-131 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-61 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-377 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-378 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-379 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-380 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-62 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-193 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-132 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-381 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-192 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-382 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-194 [[1, 128, 768]] [1, 
128, 768] 0 \n", + " LayerNorm-133 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-62 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-383 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-384 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-385 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-386 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-63 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-196 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-134 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-387 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-195 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-388 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-197 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-135 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-63 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-389 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-390 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-391 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-392 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-64 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-199 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-136 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-393 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-198 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-394 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-200 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-137 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-64 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-395 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-396 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-397 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-398 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-65 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-202 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-138 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-399 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-201 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-400 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-203 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-139 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-65 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-401 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-402 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-403 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-404 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-66 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-205 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-140 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-405 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-204 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-406 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-206 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-141 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-66 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-407 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-408 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-409 [[1, 128, 768]] [1, 128, 768] 
590,592 \n", + " Linear-410 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-67 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-208 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-142 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-411 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-207 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-412 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-209 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-143 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-67 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-413 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-414 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-415 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-416 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-68 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-211 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-144 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-417 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-210 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-418 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-212 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-145 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-68 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-419 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-420 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-421 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-422 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-69 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-214 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-146 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-423 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-213 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-424 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-215 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-147 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-69 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-425 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-426 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-427 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-428 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-70 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-217 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-148 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-429 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-216 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-430 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-218 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-149 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-70 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-431 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-432 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-433 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-434 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-71 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-220 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-150 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " 
Linear-435 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-219 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-436 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-221 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-151 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-71 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-437 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-438 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-439 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-440 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-72 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-223 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-152 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-441 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-222 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-442 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-224 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-153 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-72 [[1, 128, 768]] [1, 128, 768] 0 \n", + " TransformerEncoder-6 [[1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Linear-443 [[1, 768]] [1, 768] 590,592 \n", + " Tanh-7 [[1, 768]] [1, 768] 0 \n", + " BertPooler-6 [[1, 128, 768]] [1, 768] 0 \n", + " BertModel-6 [[1, 128]] [[1, 128, 768], [1, 768]] 0 \n", + " Dropout-225 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-444 [[1, 128, 768]] [1, 128, 21128] 16,247,432 \n", + "BertForTokenClassification-3 [[1, 128]] [1, 128, 21128] 0 \n", + "========================================================================================================================================\n", + "Total params: 118,515,080\n", + "Trainable params: 118,515,080\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------------------------------------------------------------------------------\n", + "Input size (MB): 0.00\n", + "Forward/backward pass size (MB): 219.04\n", + "Params size (MB): 452.10\n", + "Estimated Total Size (MB): 671.14\n", + "----------------------------------------------------------------------------------------------------------------------------------------\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "{'total_params': 118515080, 'trainable_params': 118515080}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from paddle.static import InputSpec\r\n", + "from paddlenlp.metrics import Perplexity\r\n", + "from paddle.optimizer import AdamW\r\n", + "\r\n", + "net = PoetryBertModel('bert-base-chinese', 128)\r\n", + "\r\n", + "token_ids = InputSpec((-1, 128), 'int64', 'token')\r\n", + "token_type_ids = InputSpec((-1, 128), 'int64', 'token_type')\r\n", + "input_mask = InputSpec((-1, 128), 'float32', 'input_mask')\r\n", + "label = InputSpec((-1, 128), 'int64', 'label')\r\n", + "\r\n", + "inputs = [token_ids, token_type_ids, input_mask]\r\n", + "labels = [label, input_mask]\r\n", + "\r\n", + "model = paddle.Model(net, inputs, labels)\r\n", + "model.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])\r\n", + "\r\n", + "model.summary(inputs, [input.dtype for input in inputs])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 模型训练\n", + "由于调用了预训练模型,再次调优,只需很少轮的训练即可达到较好的效果。\n", + "\n", + 
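Editorial aside: the `save_dir` and `save_freq` arguments described next write checkpoints that can be restored with `model.load`, exactly as the generation section later reloads `./checkpoint/final`. A minimal sketch of restoring an intermediate snapshot — the `./checkpoint/0` path is an assumption based on the `save_dir='./checkpoint'` and `save_freq=1` used in the training cell below; `PoetryBertModel`, `inputs`, `labels`, `AdamW`, `PoetryBertModelLossCriterion`, and `Perplexity` come from the cells above:

```python
# Rebuild the network and training configuration, then restore a saved epoch.
net = PoetryBertModel('bert-base-chinese', 128)
model = paddle.Model(net, inputs, labels)
model.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model.parameters()),
              loss=PoetryBertModelLossCriterion(),
              metrics=[Perplexity()])
model.load('./checkpoint/0')  # e.g. the snapshot written after the first epoch
```

From here, `model.fit` resumes training, and the restored weights can equally serve evaluation or generation.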
"训练过程中,设置save_dir参数来保存训练的模型,并通过save_freq设置保存的频率。" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddle.io import DataLoader\r\n", + "\r\n", + "train_loader = DataLoader(PoemData(poems, bert_tokenizer, 128), batch_size=128, shuffle=True)\r\n", + "model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 古诗生成\n", + "以下,我们定义一个类来利用已经训练好的模型完成古诗生成的任务。在生成古诗的过程中,我们将已经生成的内容作为输入,编码后输入模型,得到输入中每个词对应的分类结果。然后选取最后一个词的分类结果作为下一个待预测的词。下一轮中,刚刚预测的词将加入到已生成的内容中,继续进行下一个词的预测。\n", + "\n", + "在每轮预测结果的选择中,我们可以使用贪婪的方式选取最优的结果,也可以从前几个较优结果中随机选取(可以得到更多的组合),在这里,用topk进行控制。topk的设置不应太大,否则与随机生成差别不大。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\r\n", + "\r\n", + "class PoetryGen(object):\r\n", + " \"\"\"\r\n", + " 定义一个自动生成诗句的类,按照要求生成诗句\r\n", + " model: 训练得到的预测模型\r\n", + " tokenizer: 分词编码工具\r\n", + " max_length: 生成诗句的最大长度,需小于等于model所允许的最大长度\r\n", + " \"\"\"\r\n", + " def __init__(self, model, tokenizer, max_length=512):\r\n", + " self.model = model\r\n", + " self.tokenizer = tokenizer\r\n", + " self.puncs = [',', '。', '?', ';']\r\n", + " self.max_length = max_length\r\n", + "\r\n", + " def generate(self, style='', head='', topk=2):\r\n", + " \"\"\"\r\n", + " 根据要求生成诗句\r\n", + " style (str): 生成诗句的风格,写成诗句的形式,如“大漠孤烟直,长河落日圆。”\r\n", + " head (str, list): 生成诗句的开头内容。若head为str格式,则head为诗句开始内容;\r\n", + " 若head为list格式,则head中每个元素为对应位置上诗句的开始内容(即藏头诗中的头)。\r\n", + " topk (int): 从预测的topk中选取结果\r\n", + " \"\"\"\r\n", + " head_index = 0\r\n", + " style_ids = self.tokenizer.encode(style)['input_ids']\r\n", + " # 去掉结束标记\r\n", + " style_ids = style_ids[:-1]\r\n", + " head_is_list = True if isinstance(head, list) else False\r\n", + " if head_is_list:\r\n", + " poetry_ids = self.tokenizer.encode(head[head_index])['input_ids']\r\n", + " else:\r\n", + " poetry_ids = self.tokenizer.encode(head)['input_ids']\r\n", + " # 去掉开始和结束标记\r\n", + " poetry_ids = poetry_ids[1:-1]\r\n", + " break_flag = False\r\n", + " while len(style_ids) + len(poetry_ids) <= self.max_length:\r\n", + " next_word = self._gen_next_word(style_ids + poetry_ids, topk)\r\n", + " # 对于一些符号,如[UNK], [PAD], [CLS]等,其产生后对诗句无意义,直接跳过\r\n", + " if next_word in self.tokenizer.convert_tokens_to_ids(['[UNK]', '[PAD]', '[CLS]']):\r\n", + " continue\r\n", + " if head_is_list:\r\n", + " if next_word in self.tokenizer.convert_tokens_to_ids(self.puncs):\r\n", + " head_index += 1\r\n", + " if head_index < len(head):\r\n", + " new_ids = self.tokenizer.encode(head[head_index])['input_ids']\r\n", + " new_ids = [next_word] + new_ids[1:-1]\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " break_flag = True\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " else:\r\n", + " new_ids = [next_word]\r\n", + " if next_word == self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0]:\r\n", + " break\r\n", + " poetry_ids += new_ids\r\n", + " if break_flag:\r\n", + " break\r\n", + " return ''.join(self.tokenizer.convert_ids_to_tokens(poetry_ids))\r\n", + "\r\n", + " def _gen_next_word(self, known_ids, topk):\r\n", + " type_token = [0] * len(known_ids)\r\n", + " mask = [1] * len(known_ids)\r\n", + " sequence_length = len(known_ids)\r\n", + " known_ids = paddle.to_tensor([known_ids], dtype='int64')\r\n", + " type_token = paddle.to_tensor([type_token], dtype='int64')\r\n", + " 
mask = paddle.to_tensor([mask], dtype='float32')\r\n", + " logits = self.model.network.forward(known_ids, type_token, mask, sequence_length)\r\n", + " # logits中对应最后一个词的输出即为下一个词的概率\r\n", + " words_prob = logits[0, -1, :].numpy()\r\n", + " # 依概率倒序排列后,选取前topk个词\r\n", + " words_to_be_choosen = words_prob.argsort()[::-1][:topk]\r\n", + " probs_to_be_choosen = words_prob[words_to_be_choosen]\r\n", + " # 归一化\r\n", + " probs_to_be_choosen = probs_to_be_choosen / sum(probs_to_be_choosen)\r\n", + " word_choosen = np.random.choice(words_to_be_choosen, p=probs_to_be_choosen)\r\n", + " return word_choosen\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## 生成古诗示例" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 载入已经训练好的模型\r\n", + "net = PoetryBertModel('bert-base-chinese', 128)\r\n", + "model = paddle.Model(net)\r\n", + "model.load('./checkpoint/final')\r\n", + "poetry_gen = PoetryGen(model, bert_tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def poetry_show(poetry):\r\n", + " pattern = r\"([,。;?])\"\r\n", + " text = re.sub(pattern, r'\\1 ', poetry)\r\n", + " for p in text.split():\r\n", + " if p:\r\n", + " print(p)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "一雨一晴天气新,\n", + "春风桃李不胜春。\n", + "山中老去无多事,\n", + "莫道山花不是真。\n", + "山色不随人意好,\n", + "花枝只与鸟情邻。\n", + "何时得见东君面,\n", + "共醉花光醉一身。\n" + ] + } + ], + "source": [ + "# 随机生成一首诗\r\n", + "poetry = poetry_gen.generate()\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "云外有时生,\n", + "云间无限好?\n", + "月明风细细,\n", + "松响竹萧悄。\n", + "谁识此时情?\n", + "相看情未了。\n" + ] + } + ], + "source": [ + "# 生成特定风格的诗\r\n", + "poetry = poetry_gen.generate(style='会当凌绝顶,一览众山小。')\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "好好学习子,\n", + "不如癡爱官。\n", + "一身无定价,\n", + "百事有馀安。\n" + ] + } + ], + "source": [ + "# 生成特定开头的诗\r\n", + "poetry = poetry_gen.generate(head='好好学习')\r\n", + "poetry_show(poetry)" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "飞鸿过眼疾于风,\n", + "桨去帆开水拍空,\n", + "真箇老农无一箇。\n", + "好将诗卷作渔翁。\n" + ] + } + ], + "source": [ + "# 生成藏头诗\r\n", + "poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])\r\n", + "poetry_show(poetry)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} From 4fc5af6dbf7c937ba17016584f54b110fedac14c Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 17 Apr 2021 21:01:36 +0800 Subject: 
[PATCH 09/14] Add files via upload rewrite some descriptions --- .../pretrained_bert_for_poetry_generation.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb index ec6794f2..31596ea2 100644 --- a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb @@ -213,7 +213,7 @@ "collapsed": false }, "source": [ - "处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“\[CLS\]”标记,在结尾添加“\[SEP\]”标记,这些标记在BERT模型训练中扮演者特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“\[UNK\]”表示分词工具无法识别的符号,“\[PAD\]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。" + "处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“\[CLS\]”标记(“\[CLS\]”是对一些特殊任务的留空项,对于需要此项功能并需要标记语句开始的情况,一般会再加上“\[BOS\]”),在结尾添加“\[SEP\]”标记(在需要区分句子的编码中,这个标记用来将不同的句子隔开,结尾则添加“\[EOS\]”),这些标记在BERT模型训练中扮演着特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“\[UNK\]”表示分词工具无法识别的符号,“\[PAD\]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。" ] }, { From e175b3d684bc907ae57076f13d771616a11172a1e Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sun, 18 Apr 2021 17:18:04 +0800 Subject: [PATCH 10/14] Update pretrained_word_embeddings.ipynb --- .../pretrained_word_embeddings.ipynb | 158 ++++++++++-------- 1 file changed, 91 insertions(+), 67 deletions(-) diff --git a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb index a1004324..b1ae5b8e 100644 --- a/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb +++ b/paddle2.0_docs/pretrained_word_embeddings/pretrained_word_embeddings.ipynb @@ -8,10 +8,8 @@ "source": [ "# 使用预训练的词向量完成文本分类任务\n", "\n", - "**作者**: [fiyen](https://github.com/fiyen)\n", - "\n", - "**日期**: 2021.03\n", - "\n", + "**作者**: [fiyen](https://github.com/fiyen)
\n", + "**日期**: 2021.03
\n", "**摘要**: 本示例教程将会演示如何使用飞桨内置的Imdb数据集,并使用预训练词向量进行文本分类。" ] }, @@ -21,9 +19,9 @@ "collapsed": false }, "source": [ - "## 摘要\n", + "## 一、简介\n", "\n", - "在这个示例中,我们将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imdb将直接调用自飞桨2.0,同时,\n", + "在这个示例中,将使用飞桨2.0完成针对Imdb数据集(电影评论情感二分类数据集)的分类训练和测试。Imdb将直接调用自飞桨2.0,同时,\n", "利用预训练的词向量([GloVe embedding](http://nlp.stanford.edu/projects/glove/))完成任务。" ] }, @@ -33,7 +31,8 @@ "collapsed": false }, "source": [ - "## 环境设置" + "## 二、环境设置\n", + "本教程基于Paddle 2.0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) Paddle 2.0 。" ] }, { @@ -42,13 +41,23 @@ "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.0.1\n" + ] + } + ], "source": [ "import paddle\r\n", "from paddle.io import Dataset\r\n", "import numpy as np\r\n", "import paddle.text as text\r\n", - "import random" + "import random\r\n", + "\r\n", + "print(paddle.__version__)" ] }, { @@ -57,12 +66,14 @@ "collapsed": false }, "source": [ - "## 用飞桨2.0调用Imdb数据集\n", - "由于飞桨2.0提供了经过处理的Imdb数据集,我们可以方便地调用所需要的数据实例,省去了数据预处理的麻烦。目前,飞桨2.0以及内置的高质量\n", + "## 三、用飞桨2.0调用Imdb数据集\n", + "由于飞桨2.0提供了经过处理的Imdb数据集,可以方便地调用所需要的数据实例,省去了数据预处理的麻烦。目前,飞桨2.0以及内置的高质量\n", "数据集包括Conll05st、Imdb、Imikolov、Movielens、HCIHousing、WMT14和WMT16等,未来还将提供更多常用数据集的调用接口。\n", "\n", "以下定义了调用imdb训练集合测试集的方法。其中,cutoff定义了构建词典的截止大小,即数据集中出现频率在cutoff以下的不予考虑;mode定义了返回的数据用于何种用途(test: \n", - "测试集,train: 训练集)。" + "测试集,train: 训练集)。\n", + "\n", + "### 3.1 定义数据集" ] }, { @@ -83,7 +94,7 @@ "collapsed": false }, "source": [ - "调用Imdb得到的是经过编码的数据集,每个term对应一个唯一id,映射关系可以通过imdb_train.word_idx查看。将每一个样本即一条电影评论,表示成id序列。我们可以检查一下以上生成的数据内容:" + "调用Imdb得到的是经过编码的数据集,每个term对应一个唯一id,映射关系可以通过imdb_train.word_idx查看。将每一个样本即一条电影评论,表示成id序列。可以检查一下以上生成的数据内容:" ] }, { @@ -119,7 +130,7 @@ "collapsed": false }, "source": [ - "对于训练集,我们将数据的顺序打乱,以优化将要进行的分类模型训练的效果。" + "对于训练集,将数据的顺序打乱,以优化将要进行的分类模型训练的效果。" ] }, { @@ -146,7 +157,7 @@ }, "source": [ "从样本长度上可以看到,每个样本的长度是不相同的。然而,在模型的训练过程中,需要保证每个样本的长度相同,以便于构造矩阵进行批量运算。\n", - "因此,我们需要先对所有样本进行填充或截断,使样本的长度一致。" + "因此,需要先对所有样本进行填充或截断,使样本的长度一致。" ] }, { @@ -172,7 +183,7 @@ "collapsed": false }, "source": [ - "## 载入预训练向量。\n", + "### 3.2 载入预训练向量\n", "以下给出的文件较小,可以直接完全载入内存。对于大型的预训练向量,无法一次载入内存的,可以采用分批载入,并行处理的方式进行匹配。" ] }, @@ -184,11 +195,10 @@ }, "outputs": [], "source": [ - "# 下载预训练向量文件,此链接下载较慢,推荐从AI Studio的公开数据集进行下载,此文件的下载请转网址:https://aistudio.baidu.com/aistudio/datasetdetail/42051\r\n", - "!wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", - "!unzip -q glove.6B.zip\r\n", + "# !wget http://nlp.stanford.edu/data/glove.6B.zip\r\n", + "# !unzip -q glove.6B.zip\r\n", "\r\n", - "glove_path = \"./glove.6B.100d.txt\" # 请修改至glove.6B.100d.txt所在位置\r\n", + "glove_path = \"./glove.6B.100d.txt\"\r\n", "embeddings = {}" ] }, @@ -198,7 +208,7 @@ "collapsed": false }, "source": [ - "我们先观察上述GloVe预训练向量文件一行的数据:" + "观察上述GloVe预训练向量文件一行的数据:" ] }, { @@ -265,13 +275,13 @@ "collapsed": false }, "source": [ - "## 给数据集的词表匹配词向量\n", - "接下来,我们提取数据集的词表,需要注意的是,词表中的词编码的先后顺序是按照词出现的频率排列的,频率越高的词编码值越小。" + "### 3.3 给数据集的词表匹配词向量\n", + "接下来,提取数据集的词表,需要注意的是,词表中的词编码的先后顺序是按照词出现的频率排列的,频率越高的词编码值越小。" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "collapsed": false }, @@ -298,15 +308,15 @@ "collapsed": false }, "source": [ - "观察词表的后5个单词,我们发现,最后一个词是\"\\\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", + "观察词表的后5个单词,发现最后一个词是\"\\\",这个符号代表所有词表以外的词。另外,对于形式b'the',是字符串'the'\n", "的二进制编码形式,使用中注意使用b'the'.decode()来进行转换('\\'并没有进行二进制编码,注意区分)。\n", - 
"接下来,我们给词表中的每个词匹配对应的词向量。预训练词向量可能没有覆盖数据集词表中的所有词,对于没有的词,我们设该词的词\n", + "接下来,给词表中的每个词匹配对应的词向量。预训练词向量可能没有覆盖数据集词表中的所有词,对于没有的词,设该词的词\n", "向量为零向量。" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false }, @@ -329,14 +339,16 @@ "collapsed": false }, "source": [ - "## 构建基于预训练向量的Embedding\n", - "对于预训练向量的Embedding,我们一般期望它的参数不再变动,所以要设置trainable=False。如果希望在此基础上训练参数,则需要\n", + "## 四、组网\n", + "\n", + "### 4.1 构建基于预训练向量的Embedding\n", + "对于预训练向量的Embedding,一般期望它的参数不再变动,所以要设置trainable=False。如果希望在此基础上训练参数,则需要\n", "设置trainable=True。" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "collapsed": false }, @@ -357,15 +369,15 @@ "collapsed": false }, "source": [ - "## 构建分类器\n", - "这里,我们构建简单的基于一维卷积的分类模型,其结构为:Embedding->Conv1D->Pool1D->Linear。在定义Linear时,由于需要知\n", - "道输入向量的维度,我们可以按照公式[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-beta/api/paddle/nn/layer/conv/Conv2d_cn.html)\n", + "### 4.2 构建分类器\n", + "这里,构建简单的基于一维卷积的分类模型,其结构为:Embedding->Conv1D->Pool1D->Linear。在定义Linear时,由于需要知\n", + "道输入向量的维度,可以按照公式[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-beta/api/paddle/nn/layer/conv/Conv2d_cn.html)\n", "来进行计算。这里给出计算的函数如下:" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "collapsed": false }, @@ -378,12 +390,12 @@ " Layer (type) Input Shape Output Shape Param # \n", "===========================================================================\n", " Embedding-1 [[1, 2000]] [1, 2000, 100] 514,700 \n", - " Conv1D-1 [[1, 2000, 100]] [1, 998, 10] 5,010 \n", - " ReLU-1 [[1, 998, 10]] [1, 998, 10] 0 \n", - " MaxPool1D-1 [[1, 998, 10]] [1, 998, 5] 0 \n", - " Flatten-1 [[1, 998, 5]] [1, 4990] 0 \n", - " Linear-1 [[1, 4990]] [1, 2] 9,982 \n", - " Softmax-1 [[1, 2]] [1, 2] 0 \n", + " Conv1D-2 [[1, 2000, 100]] [1, 998, 10] 5,010 \n", + " ReLU-2 [[1, 998, 10]] [1, 998, 10] 0 \n", + " MaxPool1D-2 [[1, 998, 10]] [1, 998, 5] 0 \n", + " Flatten-3 [[1, 998, 5]] [1, 4990] 0 \n", + " Linear-2 [[1, 4990]] [1, 2] 9,982 \n", + " Softmax-2 [[1, 2]] [1, 2] 0 \n", "===========================================================================\n", "Total params: 529,692\n", "Trainable params: 14,992\n", @@ -403,7 +415,7 @@ "{'total_params': 529692, 'trainable_params': 14992}" ] }, - "execution_count": 12, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -442,13 +454,13 @@ "collapsed": false }, "source": [ - "## 读取数据,进行训练\n", - "我们可以利用飞桨2.0的io.Dataset模块来构建一个数据的读取器,方便地将数据进行分批训练。" + "### 4.3 读取数据,进行训练\n", + "可以利用飞桨2.0的io.Dataset模块来构建一个数据的读取器,方便地将数据进行分批训练。" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "metadata": { "collapsed": false }, @@ -459,64 +471,64 @@ "text": [ "The loss value printed in the log is the current step, and the metric is the average value of previous step.\n", "Epoch 1/10\n", - "step 586/586 [==============================] - loss: 0.7259 - acc: 0.7708 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4177 - acc: 0.8686 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.5196 - acc: 0.7006 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7767 - acc: 0.8152 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 2/10\n", - "step 586/586 [==============================] - loss: 0.5012 - acc: 0.8090 - 
4ms/step \n", + "step 586/586 [==============================] - loss: 0.4485 - acc: 0.8819 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.5776 - acc: 0.7886 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7343 - acc: 0.8150 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 3/10\n", - "step 586/586 [==============================] - loss: 0.5459 - acc: 0.8248 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4396 - acc: 0.8869 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.4988 - acc: 0.8182 - 2ms/step \n", + "step 196/196 [==============================] - loss: 0.7379 - acc: 0.8117 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 4/10\n", - "step 586/586 [==============================] - loss: 0.4274 - acc: 0.8431 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4270 - acc: 0.8926 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.4896 - acc: 0.8051 - 2ms/step \n", + "step 196/196 [==============================] - loss: 0.6714 - acc: 0.8141 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 5/10\n", - "step 586/586 [==============================] - loss: 0.4212 - acc: 0.8501 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.3806 - acc: 0.8984 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.5174 - acc: 0.8144 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7172 - acc: 0.8162 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 6/10\n", - "step 586/586 [==============================] - loss: 0.4084 - acc: 0.8605 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4466 - acc: 0.9028 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.4635 - acc: 0.8266 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.6236 - acc: 0.8026 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 7/10\n", - "step 586/586 [==============================] - loss: 0.4713 - acc: 0.8697 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4378 - acc: 0.9090 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.5352 - acc: 0.8222 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7829 - acc: 0.8070 - 2ms/step \n", "Eval samples: 6250\n", "Epoch 8/10\n", - "step 586/586 [==============================] - loss: 0.5050 - acc: 0.8745 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4609 - acc: 0.9132 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 
196/196 [==============================] - loss: 0.4925 - acc: 0.8248 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7258 - acc: 0.8118 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 9/10\n", - "step 586/586 [==============================] - loss: 0.5348 - acc: 0.8832 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4499 - acc: 0.9164 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.4936 - acc: 0.8078 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.8195 - acc: 0.8027 - 3ms/step \n", "Eval samples: 6250\n", "Epoch 10/10\n", - "step 586/586 [==============================] - loss: 0.5156 - acc: 0.8846 - 4ms/step \n", + "step 586/586 [==============================] - loss: 0.4540 - acc: 0.9212 - 4ms/step \n", "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 196/196 [==============================] - loss: 0.4882 - acc: 0.8286 - 3ms/step \n", + "step 196/196 [==============================] - loss: 0.7865 - acc: 0.8138 - 3ms/step \n", "Eval samples: 6250\n" ] } @@ -555,12 +567,12 @@ "collapsed": false }, "source": [ - "## 评估效果并用模型预测" + "## 五、评估效果并用模型预测" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "metadata": { "collapsed": false }, @@ -571,7 +583,7 @@ "text": [ "Eval begin...\n", "The loss value printed in the log is the current batch, and the metric is the average value of previous step.\n", - "step 782/782 [==============================] - loss: 0.4061 - acc: 0.8207 - 3ms/step \n", + "step 782/782 [==============================] - loss: 0.4408 - acc: 0.8085 - 2ms/step \n", "Eval samples: 25000\n", "Predict begin...\n", "step 10/10 [==============================] - 2ms/step \n", @@ -622,6 +634,18 @@ "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", "language": "python", "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" } }, "nbformat": 4, From 31dd25367ae86875172fad4c6ea80f82b25c1d2c Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Sat, 5 Jun 2021 10:07:18 +0800 Subject: [PATCH 11/14] Add files via upload MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补充了对输入的详细说明;将数据集替换为飞桨官方数据集;补充了自动写诗的说明;对章节进行了划分和标序 --- ...retrained_bert_for_poetry_generation.ipynb | 606 ++++++++++-------- 1 file changed, 336 insertions(+), 270 deletions(-) diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb index 31596ea2..0ec491be 100644 --- a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb @@ -22,6 +22,12 @@ }, "source": [ "## 摘要\n", + "古诗,中华民族最高贵的文化瑰宝,在几千年文化传承中扮演着重要的角色。诗歌已经融入中华儿女的血脉之中,上到古稀之人,下到刚入学的孩童,都能随口吟诵一首诗出来。诗句的运用体现了古今诗人对文字运用的娴熟技艺,同时寄托着诗人深远的情思。诗句或优美或刚劲,或温婉或苍凉,让人在阅读诗歌的时候,如沐春风,身临其境。\n", + "\n", + 
"美好的诗歌让人心向往之,当我们的眼球接受了美好景物时,谁不曾有“此情此景,我想吟诗一首”的冲动,却限于实力张口息声,半晌想不出一个合适的表达。此时,如果我们有一个强大的诗歌生成工具,岂不美哉?\n", + "\n", + "没问题,通过飞桨,搭建一个古诗自动生成模型将不再是一个困难的事情。在这里,我们将展示如何用飞桨快速搭建一个强大的古诗生成模型。\n", + "\n", "在这个示例中,我们将快速构建基于BERT预训练模型的古诗生成器,支持诗歌风格定制,以及生成藏头诗。模型基于飞桨2.0框架,BERT预训练模型则调用自PaddleNLP,诗歌数据集采用Github开源数据集。" ] }, @@ -31,9 +37,9 @@ "collapsed": false }, "source": [ - "## 相关内容介绍\n", + "## 1. 相关内容介绍\n", "\n", - "### PaddleNLP\n", + "### 1.1 PaddleNLP\n", "\n", "官网链接:[https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP](https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP)\n", "\n", @@ -54,7 +60,7 @@ "\n", "通过高度优化的Transformer网络实现,结合混合精度与Fleet分布式训练API,可充分利用GPU集群资源,高效完成预训练模型的分布式训练。\n", "\n", - "### BERT\n", + "### 1.2 BERT\n", "\n", "BERT的全称为Bidirectional Encoder Representations from Transformers,即基于Transformers的双向编码表示模型。BERT是Transformers应用的一次巨大的成功。在该模型提出时,其在NLP领域的11个方向上都大幅刷新了SOTA。其模型的主要特点可以归纳如下:\n", "\n", @@ -73,23 +79,74 @@ "collapsed": false }, "source": [ - "# 数据设置\n", - "在这一部分,我们对数据进行预处理,并构建训练用的数据读取器。\n", + "## 2. 数据设置\n", + "在这一部分,我们将介绍使用的数据集,并展示数据集的调用方法。\n", "\n", - "## 数据准备\n", - "诗歌数据集采用Github上开源的[中华古诗词数据库](https://github.com/chinese-poetry/chinese-poetry)。在此,我们只使用其中的唐诗和宋诗的数据即可(json文件夹下)。" + "### 2.1 数据准备\n", + "诗歌数据集采用Github上开源的[中华古诗词数据库](https://github.com/chinese-poetry/chinese-poetry)。\n", + "\n", + "该数据集包含了唐宋两朝近一万四千古诗人, 接近5.5万首唐诗加26万宋诗. 两宋时期1564位词人,21050首词。其中,唐宋诗歌内容在json文件夹下,这里只使用json文件夹下的数据即可。以下式单个数据的示例:\n", + "```\n", + "{\n", + " \"author\":string\"胡宿\"\n", + " \"paragraphs\":[\n", + " \"五粒青松護翠苔,石門岑寂斷纖埃。\"\n", + " \"水浮花片知仙路,風遞鸞聲認嘯臺。\"\n", + " \"桐井曉寒千乳斂,茗園春嫩一旗開。\"\n", + " \"馳煙未勒山亭字,可是英靈許再來。\"\n", + " ]\n", + " \"title\":string\"沖虛觀\"\n", + " \"id\":string\"dad91d22-4b8a-4c04-a0d5-8f7ca8aff4de\"\n", + "}\n", + ",…]\n", + "```\n", + "\n", + "可见,此数据集中多数诗歌内容为繁体字。不过不用担心,飞桨已经内置了该数据集并且已经进行了简体化,我们可以通过简单的几行代码迅速调用该数据集!如下所示:" ] }, { "cell_type": "code", - "execution_count": 141, + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " def convert_to_list(value, n, name, dtype=np.int):\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test_dataset 的样本数量:364\n", + "dev_dataset 的样本数量:995\n", + "train_dataset 的样本数量:294598\n" + ] + } + ], + "source": [ + "import paddlenlp as ppnlp\r\n", + "test_dataset = ppnlp.datasets.Poetry('test')\r\n", + "dev_dataset = ppnlp.datasets.Poetry('dev')\r\n", + "train_dataset = ppnlp.datasets.Poetry('train')\r\n", + "print('test_dataset 的样本数量:%d'%len(test_dataset))\r\n", + "print('dev_dataset 的样本数量:%d'%len(dev_dataset))\r\n", + "print('train_dataset 的样本数量:%d'%len(train_dataset))" + ] + }, + { + "cell_type": "markdown", "metadata": { "collapsed": false }, - "outputs": [], "source": [ - "# 下载诗歌数据集 (从镜像网站github.com.cnpmjs.org下载可提高下载速度)\r\n", - "!git clone https://github.com.cnpmjs.org/chinese-poetry/chinese-poetry" + "以上三个数据,train_dataset为训练集,test_dataset为测试集,dev_dataset为开发集。其中开发集用于训练过程的测试,以用来选择最合适的模型参数,避免模型过拟合。" ] }, { @@ -98,7 +155,8 @@ "collapsed": false }, "source": [ - "此数据集中多数诗歌内容为繁体字,为了适应基于简体中文的预训练模型,我们对数据进行预处理,将繁体字转换为简体字。首先调用Github上开源的繁转简工具。" + "### 2.2 数据处理\n", + "如下为以上数据单样本的实例:" ] }, { @@ -107,10 +165,17 @@ "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "单样本示例:['西\\x02风\\x02簇\\x02浪\\x02花\\x02,\\x02太\\x02湖\\x02连\\x02底\\x02冻\\x02。', '冷\\x02照\\x02玉\\x02奁\\x02清\\x02,\\x02一\\x02片\\x02无\\x02瑕\\x02缝\\x02。\\x02面\\x02目\\x02分\\x02明\\x02,\\x02眼\\x02睛\\x02定\\x02动\\x02。\\x02不\\x02墯\\x02虚\\x02凝\\x02裂\\x02万\\x02差\\x02,\\x02漆\\x02桶\\x02漆\\x02桶\\x02。']\n" + ] + } + ], "source": [ - "# 下载繁体转简体工具\r\n", - "!git clone https://github.com.cnpmjs.org/fiyen/cht2chs" + "print('单样本示例:%s'%test_dataset[0])" ] }, { @@ -119,70 +184,46 @@ "collapsed": false }, "source": [ - "## 数据处理\n", - "剔除数据集中的特殊符号,并将繁体转简体。" + "从单个样本的实例中可以看到,每个样本都有两句。为了方便处理,这里我们直接将两句合成一句进行训练。训练中我们将用每个诗句当前的字去预测下一个字,假设我们有样本sample, 那么我们的输入为sample\\[:-1\\],要预测的目标为sample\\[1:\\]。诗句中每个字后边都有符号'\\x02',由于对当前的训练并没有帮助,所以我们将其替换掉。" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "import os\r\n", - "import json\r\n", "import re\r\n", - "from cht2chs.langconv import cht_to_chs\r\n", - "\r\n", - "def sentenceParse(para):\r\n", - " \"\"\"\r\n", - " 剔除诗歌字符中的非文字符号以及数字\r\n", - " \"\"\"\r\n", - " result, number = re.subn(u\"(.*)\", \"\", para)\r\n", - " result, number = re.subn(u\"{.*}\", \"\", result)\r\n", - " result, number = re.subn(u\"《.*》\", \"\", result)\r\n", - " result, number = re.subn(u\"《.*》\", \"\", result)\r\n", - " result, number = re.subn(u\"[\\]\\[]\", \"\", result)\r\n", - " r = \"\"\r\n", - " for s in result:\r\n", - " if s not in set('0123456789-'):\r\n", - " r += s\r\n", - " r, number = re.subn(u\"。。\", u\"。\", r)\r\n", - " return r\r\n", - "\r\n", - "\r\n", - "def data_preprocess(poem_dir='./chinese-poetry/json', len_limit=120):\r\n", - " \"\"\"\r\n", - " 预处理诗歌数据,返回符合要求的诗歌列表\r\n", - " \"\"\"\r\n", - " poems = []\r\n", - " for f in os.listdir(poem_dir):\r\n", - " if f.endswith('.json'):\r\n", - " json_data = json.load(open(os.path.join(poem_dir, f)))\r\n", - " for d in json_data:\r\n", - " try:\r\n", - " poem = ''.join(d['paragraphs'])\r\n", - " poem = sentenceParse(poem)\r\n", - " # 控制长度,并将繁体字转换为简体字\r\n", - " if 
len(poem) <= len_limit:\r\n", - " poems.append(cht_to_chs(poem))\r\n", - " except:\r\n", - " continue\r\n", - " return poems" + "def data_preprocess(dataset):\r\n", + " for i, data in enumerate(dataset):\r\n", + " dataset.data[i] = ''.join(dataset[i])\r\n", + " dataset.data[i] = re.sub('\\x02', '', dataset[i])\r\n", + " return dataset" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "处理后的单样本示例:西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。\n" + ] + } + ], "source": [ "# 开始处理\r\n", - "poems = data_preprocess()" + "test_dataset = data_preprocess(test_dataset)\r\n", + "dev_dataset = data_preprocess(dev_dataset)\r\n", + "train_dataset = data_preprocess(train_dataset)\r\n", + "print('处理后的单样本示例:%s'%test_dataset[0])" ] }, { @@ -196,11 +237,19 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2021-06-05 09:16:08,170] [ INFO] - Found /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt\n" + ] + } + ], "source": [ "from paddlenlp.transformers import BertTokenizer\r\n", "\r\n", @@ -218,7 +267,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "collapsed": false }, @@ -227,18 +276,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "楚王台榭荆榛里,屈指江山俎豆中。\n", - "[101, 3504, 4374, 1378, 3531, 5769, 3527, 7027, 8024, 2235, 2900, 3736, 2255, 917, 6486, 704, 511, 102]\n", - "[CLS]楚王台榭荆榛里,屈指江山俎豆中。[SEP]\n", - "百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。\n", - "[101, 4636, 2399, 2129, 4373, 4767, 8024, 676, 7027, 5811, 2687, 740, 511, 1765, 2970, 5769, 7305, 6818, 8024, 4170, 6837, 3727, 3717, 7270, 511, 102]\n", - "[CLS]百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。[SEP]\n" + "西风簇浪花,太湖连底冻。冷照玉奁清,一片无瑕缝。面目分明,眼睛定动。不墯虚凝裂万差,漆桶漆桶。\n", + "[101, 6205, 7599, 5077, 3857, 5709, 8024, 1922, 3959, 6825, 2419, 1108, 511, 1107, 4212, 4373, 100, 3926, 8024, 671, 4275, 3187, 4442, 5361, 511, 7481, 4680, 1146, 3209, 8024, 4706, 4714, 2137, 1220, 511, 679, 100, 5994, 1125, 6162, 674, 2345, 8024, 4024, 3446, 4024, 3446, 511, 102]\n", + "[CLS]西风簇浪花,太湖连底冻。冷照玉[UNK]清,一片无瑕缝。面目分明,眼睛定动。不[UNK]虚凝裂万差,漆桶漆桶。[SEP]\n", + "大道分明在眼前,时人不会悮归泉。黄芽本是乾坤气,神水根基与汞连。\n", + "[101, 1920, 6887, 1146, 3209, 1762, 4706, 1184, 8024, 3198, 782, 679, 833, 100, 2495, 3787, 511, 7942, 5715, 3315, 3221, 746, 1787, 3698, 8024, 4868, 3717, 3418, 1825, 680, 3735, 6825, 511, 102]\n", + "[CLS]大道分明在眼前,时人不会[UNK]归泉。黄芽本是乾坤气,神水根基与汞连。[SEP]\n" ] } ], "source": [ "# 处理效果展示\r\n", - "for poem in poems[6:8]:\r\n", + "for poem in test_dataset[0:2]:\r\n", " token_poem, _ = bert_tokenizer.encode(poem).values()\r\n", " print(poem)\r\n", " print(token_poem)\r\n", @@ -251,15 +300,29 @@ "collapsed": false }, "source": [ - "## 构造数据读取器\n", + "### 2.3 构造数据读取器\n", "预处理数据后,我们基于飞桨2.0构造数据读取器,以适应后续模型的训练。\n", "\n", + "在构造读取器之前,我们先来了解一下BERT模型的输入是什么样子的。如下图所示:\n", + "\n", + "![](https://ai-studio-static-online.cdn.bcebos.com/d35a235024fe431e8a3c3cac62064836e53e78577d8b42d4aaa806821c781c98)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "上图中可以清晰地显示出输入数据的具体样式,包括三个部分:Token Embeddings, Segment Embeddings, Position Embeddings。在这里,Embeddings理解为嵌入,即将一个元素表示成一个1 * n的向量的形式,用以表示这个元素在一个向量空间的相对位置。这是中文文本处理如今比较普遍采用的方式。在这里,Token Embeddings为词嵌入,将分词后的词元素映射成一个个1 * n的向量。除此之外,Segment 
Embeddings表示每个词元素属于何种角色。具体来说,当我们需要区分一个输入中不同语句时,如在对话模型中,区分输入中每一句话是哪个对象发出的,可以用Segment Embeddings。Position Embeddings为Transformer类模型的特色,由于此类自注意力机制无法区分距离的远近,引入了该嵌入来增加距离产生的偏置。通常情况下,Position为一个从句首到句尾渐增的数列,如\\[0,1,2,3,4,5,...,n-1\\]即表示一个长度为n的输入的Position。如何得到Embeddings呢?通常是构造一个N * n的矩阵,所有元素被唯一对应一个位置索引,元素数量不大于N。每一个元素的嵌入通过其对应的索引调取矩阵对应的行的n个列上的元素,即1 * n的向量。在这个项目中,由于不需要区分每一句的角色,Segment Embeddings可以设为一样的,即索引都为相同的值 (如0)。由于飞桨的BERT模型会自动处理Segment Embeddings和Position Embeddings,在构造输入的时候我们可以忽略这两项。在进行下一步计算前,所有类型的进行加和,每个词元素对应一个合成的嵌入向量。\n", + "\n", "需注意以下类定义中包含填充内容,使输入样本对齐到一个特定的长度,以便于模型进行批处理运算。因此在得到数据读取器的实例时,需注意参数max_len,其不超过模型所支持的最大长度(PaddleNLP默认的序列最长长度为512)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, @@ -283,12 +346,12 @@ " self.max_len = max_len\r\n", " \r\n", " def __getitem__(self, idx):\r\n", - " line = poems[idx]\r\n", + " line = self.poems[idx]\r\n", " token_line = self.tokenizer.encode(line)\r\n", " token, token_type = token_line['input_ids'], token_line['token_type_ids']\r\n", " if len(token) > self.max_len + 1:\r\n", - " token = token[:self.max_len] + token[-1]\r\n", - " token_type = token_type[:self.max_len] + token_type[-1]\r\n", + " token = token[:self.max_len] + token[-1:]\r\n", + " token_type = token_type[:self.max_len] + token_type[-1:]\r\n", " input_token, input_token_type = token[:-1], token_type[:-1]\r\n", " label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')\r\n", " # 输入填充\r\n", @@ -307,7 +370,7 @@ "collapsed": false }, "source": [ - "# 模型设置与训练\n", + "## 3. 模型设置与训练\n", "在这一部分,我们将快速搭建基于BERT预训练模型的古诗生成器,并对模型进行训练。" ] }, @@ -317,7 +380,7 @@ "collapsed": false }, "source": [ - "## 预训练BERT模型\n", + "### 3.1 预训练BERT模型\n", "古诗生成是一个文本生成的过程,在实际中模型无法获知还未生成的内容,也即BERT中的双向关系中只能捕捉到前向关系而不能捕捉到后向关系。这个限制我们可以通过添加注意力掩码(attention mask)来屏蔽掉后向的关系,使模型无法注意到还未生成的内容,从而使BERT仍能完成文本生成任务。\n", "\n", "进一步地,我们可以将文本生成简化为基于BERT的词分类模型(理解为词性标注),即赋予每个词一个标签,该标签即该词后的下一个词是什么。因此,我们直接调用PaddleNLP的BERT词分类模型即可看,需注意模型分类的类别为词表长度。" @@ -325,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "collapsed": false }, @@ -376,13 +439,13 @@ "collapsed": false }, "source": [ - "## 定义模型损失\n", + "### 3.2 定义模型损失\n", "由于真实值中有相当一部分是填充内容,我们需重写交叉熵损失,使其忽略填充内容带来的损失。" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "collapsed": false }, @@ -401,7 +464,7 @@ "collapsed": false }, "source": [ - "## 模型准备\n", + "### 3.3 模型准备\n", "针对预训练模型的训练,需使用较小的学习率(learning_rate)进行调优。" ] }, @@ -416,8 +479,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2021-04-16 19:31:55,774] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams\n", - "[2021-04-16 19:32:00,118] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n" + "[2021-06-05 09:16:08,229] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams\n", + "[2021-06-05 09:16:17,229] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 
'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/core/fromnumeric.py:87: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", + " return ufunc.reduce(obj, axis, dtype, out, **passkwargs)\n" ] }, { @@ -427,176 +492,176 @@ "----------------------------------------------------------------------------------------------------------------------------------------\n", " Layer (type) Input Shape Output Shape Param # \n", "========================================================================================================================================\n", - " Embedding-16 [[1, 128]] [1, 128, 768] 16,226,304 \n", - " Embedding-17 [[1, 128]] [1, 128, 768] 393,216 \n", - " Embedding-18 [[1, 128]] [1, 128, 768] 1,536 \n", - " LayerNorm-129 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Dropout-188 [[1, 128, 768]] [1, 128, 768] 0 \n", - " BertEmbeddings-6 [] [1, 128, 768] 0 \n", - " Linear-371 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-372 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-373 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-374 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-61 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-190 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-130 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-375 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-189 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-376 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-191 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-131 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-61 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-377 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-378 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-379 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-380 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-62 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-193 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-132 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-381 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-192 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-382 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-194 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-133 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-62 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-383 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-384 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-385 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-386 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-63 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-196 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-134 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-387 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-195 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-388 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-197 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-135 [[1, 128, 768]] [1, 128, 768] 1,536 
\n", - " TransformerEncoderLayer-63 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-389 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-390 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-391 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-392 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-64 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-199 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-136 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-393 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-198 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-394 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-200 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-137 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-64 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-395 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-396 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-397 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-398 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-65 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-202 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-138 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-399 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-201 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-400 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-203 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-139 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-65 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-401 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-402 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-403 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-404 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-66 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-205 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-140 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-405 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-204 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-406 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-206 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-141 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-66 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-407 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-408 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-409 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-410 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-67 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-208 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-142 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-411 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-207 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-412 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-209 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-143 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-67 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-413 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-414 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-415 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-416 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - 
" MultiHeadAttention-68 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-211 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-144 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-417 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-210 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-418 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-212 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-145 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-68 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-419 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-420 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-421 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-422 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-69 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-214 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-146 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-423 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-213 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-424 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-215 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-147 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-69 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-425 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-426 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-427 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-428 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-70 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-217 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-148 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-429 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-216 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-430 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-218 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-149 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-70 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-431 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-432 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-433 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-434 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-71 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-220 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-150 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-435 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-219 [[1, 128, 3072]] [1, 128, 3072] 0 \n", - " Linear-436 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-221 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-151 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-71 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-437 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-438 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-439 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " Linear-440 [[1, 128, 768]] [1, 128, 768] 590,592 \n", - " MultiHeadAttention-72 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Dropout-223 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-152 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " Linear-441 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", - " Dropout-222 [[1, 
128, 3072]] [1, 128, 3072] 0 \n", - " Linear-442 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", - " Dropout-224 [[1, 128, 768]] [1, 128, 768] 0 \n", - " LayerNorm-153 [[1, 128, 768]] [1, 128, 768] 1,536 \n", - " TransformerEncoderLayer-72 [[1, 128, 768]] [1, 128, 768] 0 \n", - " TransformerEncoder-6 [[1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", - " Linear-443 [[1, 768]] [1, 768] 590,592 \n", - " Tanh-7 [[1, 768]] [1, 768] 0 \n", - " BertPooler-6 [[1, 128, 768]] [1, 768] 0 \n", - " BertModel-6 [[1, 128]] [[1, 128, 768], [1, 768]] 0 \n", - " Dropout-225 [[1, 128, 768]] [1, 128, 768] 0 \n", - " Linear-444 [[1, 128, 768]] [1, 128, 21128] 16,247,432 \n", - "BertForTokenClassification-3 [[1, 128]] [1, 128, 21128] 0 \n", + " Embedding-1 [[1, 128]] [1, 128, 768] 16,226,304 \n", + " Embedding-2 [[1, 128]] [1, 128, 768] 393,216 \n", + " Embedding-3 [[1, 128]] [1, 128, 768] 1,536 \n", + " LayerNorm-1 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Dropout-1 [[1, 128, 768]] [1, 128, 768] 0 \n", + " BertEmbeddings-1 [] [1, 128, 768] 0 \n", + " Linear-1 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-2 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-3 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-4 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-1 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-3 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-2 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-5 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-2 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-6 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-4 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-3 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-1 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-7 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-8 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-9 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-10 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-2 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-6 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-4 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-11 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-5 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-12 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-7 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-5 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-2 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-13 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-14 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-15 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-16 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-3 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-9 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-6 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-17 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-8 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-18 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-10 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-7 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-3 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-19 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-20 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-21 [[1, 
128, 768]] [1, 128, 768] 590,592 \n", + " Linear-22 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-4 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-12 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-8 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-23 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-11 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-24 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-13 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-9 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-4 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-25 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-26 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-27 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-28 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-5 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-15 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-10 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-29 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-14 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-30 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-16 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-11 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-5 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-31 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-32 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-33 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-34 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-6 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-18 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-12 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-35 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-17 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-36 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-19 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-13 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-6 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-37 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-38 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-39 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-40 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-7 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-21 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-14 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-41 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-20 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-42 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-22 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-15 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-7 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-43 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-44 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-45 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-46 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-8 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-24 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-16 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-47 [[1, 128, 768]] [1, 128, 3072] 
2,362,368 \n", + " Dropout-23 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-48 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-25 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-17 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-8 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-49 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-50 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-51 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-52 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-9 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-27 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-18 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-53 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-26 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-54 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-28 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-19 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-9 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-55 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-56 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-57 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-58 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-10 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-30 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-20 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-59 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-29 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-60 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-31 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-21 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-10 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-61 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-62 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-63 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-64 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-11 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-33 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-22 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-65 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-32 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-66 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-34 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-23 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-11 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-67 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-68 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-69 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " Linear-70 [[1, 128, 768]] [1, 128, 768] 590,592 \n", + " MultiHeadAttention-12 [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]] [1, 128, 768] 0 \n", + " Dropout-36 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-24 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " Linear-71 [[1, 128, 768]] [1, 128, 3072] 2,362,368 \n", + " Dropout-35 [[1, 128, 3072]] [1, 128, 3072] 0 \n", + " Linear-72 [[1, 128, 3072]] [1, 128, 768] 2,360,064 \n", + " Dropout-37 [[1, 128, 768]] [1, 128, 768] 0 \n", + " LayerNorm-25 [[1, 128, 768]] [1, 128, 768] 1,536 \n", + " TransformerEncoderLayer-12 [[1, 128, 768]] [1, 128, 768] 0 \n", + " TransformerEncoder-1 [[1, 128, 768], [1, 
1, 128, 128]] [1, 128, 768] 0 \n", + " Linear-73 [[1, 768]] [1, 768] 590,592 \n", + " Tanh-2 [[1, 768]] [1, 768] 0 \n", + " BertPooler-1 [[1, 128, 768]] [1, 768] 0 \n", + " BertModel-1 [[1, 128]] [[1, 128, 768], [1, 768]] 0 \n", + " Dropout-38 [[1, 128, 768]] [1, 128, 768] 0 \n", + " Linear-74 [[1, 128, 768]] [1, 128, 21128] 16,247,432 \n", + "BertForTokenClassification-1 [[1, 128]] [1, 128, 21128] 0 \n", "========================================================================================================================================\n", "Total params: 118,515,080\n", "Trainable params: 118,515,080\n", @@ -648,7 +713,7 @@ "collapsed": false }, "source": [ - "## 模型训练\n", + "### 3.4 模型训练\n", "由于调用了预训练模型,再次调优,只需很少轮的训练即可达到较好的效果。\n", "\n", "训练过程中,设置save_dir参数来保存训练的模型,并通过save_freq设置保存的频率。" @@ -656,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 12, "metadata": { "collapsed": false }, @@ -664,8 +729,9 @@ "source": [ "from paddle.io import DataLoader\r\n", "\r\n", - "train_loader = DataLoader(PoemData(poems, bert_tokenizer, 128), batch_size=128, shuffle=True)\r\n", - "model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1)" + "train_loader = DataLoader(PoemData(train_dataset, bert_tokenizer, 128), batch_size=128, shuffle=True)\r\n", + "dev_loader = DataLoader(PoemData(dev_dataset, bert_tokenizer, 128), batch_size=32, shuffle=True)\r\n", + "model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1, eval_data=dev_loader, eval_freq=1)" ] }, { @@ -674,7 +740,7 @@ "collapsed": false }, "source": [ - "# 古诗生成\n", + "## 4. 古诗生成\n", "以下,我们定义一个类来利用已经训练好的模型完成古诗生成的任务。在生成古诗的过程中,我们将已经生成的内容作为输入,编码后输入模型,得到输入中每个词对应的分类结果。然后选取最后一个词的分类结果作为下一个待预测的词。下一轮中,刚刚预测的词将加入到已生成的内容中,继续进行下一个词的预测。\n", "\n", "在每轮预测结果的选择中,我们可以使用贪婪的方式选取最优的结果,也可以从前几个较优结果中随机选取(可以得到更多的组合),在这里,用topk进行控制。topk的设置不应太大,否则与随机生成差别不大。" @@ -682,7 +748,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false }, @@ -773,12 +839,12 @@ "collapsed": false }, "source": [ - "## 生成古诗示例" + "### 4.1 生成古诗示例" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "collapsed": false }, @@ -793,7 +859,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false }, @@ -809,7 +875,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "collapsed": false }, @@ -837,7 +903,7 @@ }, { "cell_type": "code", - "execution_count": 170, + "execution_count": null, "metadata": { "collapsed": false }, @@ -863,7 +929,7 @@ }, { "cell_type": "code", - "execution_count": 175, + "execution_count": null, "metadata": { "collapsed": false }, @@ -887,7 +953,7 @@ }, { "cell_type": "code", - "execution_count": 178, + "execution_count": null, "metadata": { "collapsed": false }, @@ -896,10 +962,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "飞鸿过眼疾于风,\n", - "桨去帆开水拍空,\n", - "真箇老农无一箇。\n", - "好将诗卷作渔翁。\n" + "飞来峰下白莲宫,\n", + "桨去帆来一叶东。\n", + "真境自然非世外,\n", + "好山长与白云通?\n" ] } ], From 7f25d54ff7bc69ea9b202beca080c7d895d0dffb Mon Sep 17 00:00:00 2001 From: WinSun <35953131+fiyen@users.noreply.github.com> Date: Wed, 9 Jun 2021 15:41:03 +0800 Subject: [PATCH 12/14] Add files via upload update the newest url for paddlenlp; update the new API for paddlenlp. 
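作为补充,上面 generate(head=...) 生成藏头诗的核心思路大致可以抽象为下面的草稿:每一句的第一个字被强制设为对应的藏头字,其余位置照常逐字采样,采样到标点即视为一句结束。其中 next_word 是一个假设的接口,对应前文 PoetryGen 中“按 topk 概率选下一个字”的逻辑,此处仅作原理示意,并非正文类的实际实现。

```python
# 藏头诗生成思路草稿:head 中的每个字强制作为一句的开头,
# 其余的字由模型逐字采样补全。next_word(text) 为假设的接口,
# 含义是“根据已生成内容 text 采样下一个字”。
def generate_acrostic(head, next_word, max_line_len=15):
    poem = ''
    for ch in head:
        poem += ch  # 每句第一个字固定为藏头字
        for _ in range(max_line_len):
            w = next_word(poem)
            poem += w
            if w in ',。;?':  # 采样到标点即认为一句结束
                break
    return poem

# 用法示意(next_word 需由训练好的模型包装得到):
# print(generate_acrostic('飞桨真好', next_word))
```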
--- ...retrained_bert_for_poetry_generation.ipynb | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb index 0ec491be..68654f11 100644 --- a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb +++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb @@ -10,7 +10,7 @@ "\n", "**作者**:[fiyen](https://github.com/fiyen)\n", "\n", - "**日期**:2021.04\n", + "**日期**:2021.06\n", "\n", "**摘要**:本示例教程将会演示如何使用飞桨2.0以及PaddleNLP快速实现用BERT预训练模型生成高质量诗歌。" ] @@ -41,7 +41,7 @@ "\n", "### 1.1 PaddleNLP\n", "\n", - "官网链接:[https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP](https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP)\n", + "官网链接:[https://github.com/PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)\n", "\n", "![](https://github.com/fiyen/models/raw/release/2.0-beta/PaddleNLP/docs/imgs/paddlenlp.png)\n", "\n", @@ -106,7 +106,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# 更新paddlenlp版本\r\n", + "!pip install --upgrade paddlenlp" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": { "collapsed": false }, @@ -115,9 +127,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" + "2021-06-09 15:33:46,762 - INFO - unique_endpoints {''}\n", + "2021-06-09 15:33:46,763 - INFO - Downloading poetry.tar.gz from https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz\n", + "100%|██████████| 33059/33059 [00:00<00:00, 40500.87it/s]\n", + "2021-06-09 15:33:47,932 - INFO - File /home/aistudio/.paddlenlp/datasets/Poetry/poetry.tar.gz md5 checking...\n", + "2021-06-09 15:33:48,229 - INFO - Decompressing /home/aistudio/.paddlenlp/datasets/Poetry/poetry.tar.gz...\n" ] }, { @@ -131,10 +145,10 @@ } ], "source": [ - "import paddlenlp as ppnlp\r\n", - "test_dataset = ppnlp.datasets.Poetry('test')\r\n", - "dev_dataset = ppnlp.datasets.Poetry('dev')\r\n", - "train_dataset = ppnlp.datasets.Poetry('train')\r\n", + "import paddlenlp\r\n", + "test_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('test'), lazy=False)\r\n", + "dev_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('dev'), lazy=False)\r\n", + "train_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('train'), lazy=False)\r\n", "print('test_dataset 的样本数量:%d'%len(test_dataset))\r\n", "print('dev_dataset 的样本数量:%d'%len(dev_dataset))\r\n", "print('train_dataset 的样本数量:%d'%len(train_dataset))" @@ -161,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false }, @@ -170,7 +184,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "单样本示例:['西\\x02风\\x02簇\\x02浪\\x02花\\x02,\\x02太\\x02湖\\x02连\\x02底\\x02冻\\x02。', '冷\\x02照\\x02玉\\x02奁\\x02清\\x02,\\x02一\\x02片\\x02无\\x02瑕\\x02缝\\x02。\\x02面\\x02目\\x02分\\x02明\\x02,\\x02眼\\x02睛\\x02定\\x02动\\x02。\\x02不\\x02墯\\x02虚\\x02凝\\x02裂\\x02万\\x02差\\x02,\\x02漆\\x02桶\\x02漆\\x02桶\\x02。']\n" + "单样本示例:{'tokens': '西\\x02风\\x02簇\\x02浪\\x02花\\x02,\\x02太\\x02湖\\x02连\\x02底\\x02冻\\x02。', 'labels': '冷\\x02照\\x02玉\\x02奁\\x02清\\x02,\\x02一\\x02片\\x02无\\x02瑕\\x02缝\\x02。\\x02面\\x02目\\x02分\\x02明\\x02,\\x02眼\\x02睛\\x02定\\x02动\\x02。\\x02不\\x02墯\\x02虚\\x02凝\\x02裂\\x02万\\x02差\\x02,\\x02漆\\x02桶\\x02漆\\x02桶\\x02。'}\n" ] } ], @@ -189,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false }, @@ -198,14 +212,14 @@ "import re\r\n", "def data_preprocess(dataset):\r\n", " for i, data in enumerate(dataset):\r\n", - " dataset.data[i] = ''.join(dataset[i])\r\n", + " dataset.data[i] = ''.join(list(dataset[i].values()))\r\n", " dataset.data[i] = re.sub('\\x02', '', dataset[i])\r\n", " return dataset" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "collapsed": false }, @@ -237,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -246,7 +260,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2021-06-05 09:16:08,170] [ INFO] - Found /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt\n" + "[2021-06-09 15:35:07,276] [ INFO] - Downloading bert-base-chinese-vocab.txt from https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-chinese-vocab.txt\n", + "100%|██████████| 107/107 [00:00<00:00, 23081.18it/s]\n" ] } ], @@ -267,7 +282,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "collapsed": false }, 
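结合本补丁对 load_dataset 调用方式的更新,这里给出一个把数据加载与预处理串联起来的最小草稿,与正文逻辑一致;假设所用 paddlenlp 版本支持一次传入多个 splits(下一个补丁正是改成了这种合并写法):

```python
# 数据加载与预处理串联的最小草稿:一次加载三个划分,
# 再把每条样本的 tokens/labels 拼接成整句,并去掉分隔符 '\x02'。
import re
import paddlenlp

test_dataset, dev_dataset, train_dataset = paddlenlp.datasets.load_dataset(
    'poetry', splits=('test', 'dev', 'train'), lazy=False)

def data_preprocess(dataset):
    for i, data in enumerate(dataset):
        # 样本为 {'tokens': ..., 'labels': ...} 形式的字典,按值拼接成整句
        dataset.data[i] = re.sub('\x02', '', ''.join(list(data.values())))
    return dataset

train_dataset = data_preprocess(train_dataset)
print(train_dataset[0])  # 预期得到一首完整的诗句字符串
```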
From fbec11f55ad20941d5fb12e8bb52ece96bd7cb80 Mon Sep 17 00:00:00 2001
From: WinSun <35953131+fiyen@users.noreply.github.com>
Date: Thu, 10 Jun 2021 12:29:43 +0800
Subject: [PATCH 13/14] Add files via upload

Merge the load_dataset calls for the test, dev, and train splits into one op.
---
 .../pretrained_bert_for_poetry_generation.ipynb   | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
index 68654f11..43829854 100644
--- a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
+++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
@@ -118,22 +118,11 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": 2,
     "metadata": {
      "collapsed": false
     },
     "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "2021-06-09 15:33:46,762 - INFO - unique_endpoints {''}\n",
-       "2021-06-09 15:33:46,763 - INFO - Downloading poetry.tar.gz from https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz\n",
-       "100%|██████████| 33059/33059 [00:00<00:00, 40500.87it/s]\n",
-       "2021-06-09 15:33:47,932 - INFO - File /home/aistudio/.paddlenlp/datasets/Poetry/poetry.tar.gz md5 checking...\n",
-       "2021-06-09 15:33:48,229 - INFO - Decompressing /home/aistudio/.paddlenlp/datasets/Poetry/poetry.tar.gz...\n"
-      ]
-     },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -146,9 +135,7 @@
    ],
    "source": [
     "import paddlenlp\r\n",
-    "test_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('test'), lazy=False)\r\n",
-    "dev_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('dev'), lazy=False)\r\n",
-    "train_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('train'), lazy=False)\r\n",
+    "test_dataset, dev_dataset, train_dataset = paddlenlp.datasets.load_dataset('poetry', splits=('test','dev','train'), lazy=False)\r\n",
     "print('Number of samples in test_dataset: %d' % len(test_dataset))\r\n",
     "print('Number of samples in dev_dataset: %d' % len(dev_dataset))\r\n",
     "print('Number of samples in train_dataset: %d' % len(train_dataset))"

From 9fa7c4caa9e45e1512e795a9fc09e6a3be57361c Mon Sep 17 00:00:00 2001
From: WinSun <35953131+fiyen@users.noreply.github.com>
Date: Sun, 27 Jun 2021 20:46:00 +0800
Subject: [PATCH 14/14] Add files via upload

Add more descriptions.
---
 ...retrained_bert_for_poetry_generation.ipynb | 20 +++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
index 43829854..c103b616 100644
--- a/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
+++ b/paddle2.0_docs/pretrained_bert_for_poetry_generation/pretrained_bert_for_poetry_generation.ipynb
@@ -385,7 +385,18 @@
    "### 3.1 The Pre-trained BERT Model\n",
    "Poetry generation is a text-generation task. At inference time the model cannot see content that has not yet been generated, so of BERT's bidirectional relations only the forward direction can be used. We can enforce this restriction by adding an attention mask that blocks the backward relations, so the model cannot attend to content not yet generated; with this mask, BERT can still perform text generation.\n",
    "\n",
-    "Furthermore, we can reduce text generation to a BERT-based token-classification model (think of it as POS tagging): each token is assigned a label, and that label is simply the token that follows it. We can therefore directly call PaddleNLP's BERT token-classification model; note that the number of classes equals the vocabulary size."
+    "Furthermore, we can reduce text generation to a BERT-based token-classification model (think of it as POS tagging): each token is assigned a label, and that label is simply the token that follows it. The table below gives an example. For the verse “床前明月光,疑是地上霜。”, the training input is “床前明月光,疑是地上霜” (note the final “。” is dropped), and the targets are the labels of the input tokens, which we set to “前明月光,疑是地上霜。”. That is, the character “床” is labeled “前”, the character “前” is labeled “明”, ..., and the character “霜” is labeled “。”. We can therefore directly call PaddleNLP's BERT token-classification model; note that the number of classes equals the vocabulary size.\n",
+    "\n",
+    "|Sentence|床前明月光,疑是地上霜。|\n",
+    "|:--:|:--:|\n",
+    "|Input|床前明月光,疑是地上霜|\n",
+    "|Targets|前明月光,疑是地上霜。|\n",
+    "|The flow is as follows:||\n",
+    "|Given: 床|Predict: 前|\n",
+    "|Given: 床前|Predict: 明|\n",
+    "|Given: 床前明|Predict: 月|\n",
+    "|......|......|\n",
+    "|Given: 床前明月光,疑是地上霜|Predict: 。|\n"
   ]
  },
 {
@@ -410,6 +421,9 @@
    "        self.bert_for_class = BertForTokenClassification(bert_model, self.vocab_size)\r\n",
    "        # Build a lower-triangular matrix to mask the information after each position in the sentence\r\n",
    "        self.sequence_length = input_length\r\n",
+    "        # lower_triangle_mask is an input_length x input_length lower-triangular matrix (main diagonal included), used as part of the\r\n",
+    "        # attention mask: in forward(), positions holding 0 are mapped to a very large negative value, so the masked positions receive\r\n",
+    "        # attention weights of roughly 0. The lower-triangular form follows from how the Transformer computes multi-head attention; see the relevant papers for details.\r\n",
    "        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n",
    "\r\n",
    "    def forward(self, token, token_type, input_mask, input_length=None):\r\n",
@@ -422,6 +436,8 @@
    "        attention_mask = paddle.matmul(mask_left, mask_right)\r\n",
    "        # Positions that are valid in the attention computation\r\n",
    "        if input_length is not None:\r\n",
+    "            # Recomputed here because at inference time the input length may differ from the length set at instantiation. During training,\r\n",
+    "            # this model assumes inputs are padded to a uniform length; not strictly necessary, but uniform lengths are easier to handle (at the cost of extra GPU memory).\r\n",
    "            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))\r\n",
    "        else:\r\n",
    "            lower_triangle_mask = self.lower_triangle_mask\r\n",
@@ -743,7 +759,7 @@
   },
   "source": [
    "## 4. Poetry Generation\n",
-    "Below we define a class that uses the trained model to generate poems. During generation, the content generated so far is encoded and fed to the model, which yields a classification result for every input token; the classification result of the last token is then taken as the next token to predict. In the next round, the token just predicted is appended to the generated content, and prediction of the following token continues.\n",
+    "Below we define a class that uses the trained model to generate poems. During generation, the content generated so far is encoded and fed to the model, which yields a classification result for every input token; the classification result of the last token is then taken as the token predicted from the current content. In the next round, the token just predicted is appended to the generated content, and prediction of the following token continues.\n",
    "\n",
    "When choosing among each round's predictions, we can greedily take the single best result, or sample at random from the top few candidates (which yields more varied combinations); here this is controlled by topk. topk should not be set too large, otherwise the output differs little from random generation."
  ]
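
A note on the label construction described in section 3.1 of the patched notebook: the input/label pair in the table is produced simply by shifting the verse one character to the left. The sketch below illustrates this in plain Python; `build_example` is a hypothetical helper for illustration, not part of the notebook.

```python
# Minimal sketch of the label construction from section 3.1:
# the label of each character is the next character of the verse.
verse = "床前明月光,疑是地上霜。"

def build_example(verse: str):
    """Hypothetical helper: build a (tokens, labels) pair by a one-position shift."""
    tokens = verse[:-1]   # the input drops the final character ("。")
    labels = verse[1:]    # the label at position i is the character at i + 1
    return tokens, labels

tokens, labels = build_example(verse)
for t, l in zip(tokens, labels):
    print(f"given ...{t} -> predict {l}")
```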
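A note on the attention mask commented in the hunk above: a minimal standalone sketch of the same logic follows. The names `input_mask`, `mask_left`, `mask_right`, and `attention_mask` follow the notebook's `forward()`; the batch shape, the `-1e9` constant (standing in for the "very large negative value" the comment mentions), and the use of `paddle.full` in place of `paddle.tensor.full` are assumptions for illustration.

```python
import paddle

seq_len = 5
# 1.0 marks a real token, 0.0 a padding position (here: 4 real tokens, 1 pad)
input_mask = paddle.to_tensor([[1., 1., 1., 1., 0.]])

# Outer product expands the 1-D padding mask into a (seq_len, seq_len) validity matrix
mask_left = input_mask.reshape([-1, seq_len, 1])
mask_right = input_mask.reshape([-1, 1, seq_len])
attention_mask = paddle.matmul(mask_left, mask_right)

# Lower-triangular matrix (diagonal included) forbids attending to future positions
lower_triangle_mask = paddle.tril(paddle.full((seq_len, seq_len), 1, 'float32'))
attention_mask = attention_mask * lower_triangle_mask

# Zeros become a large negative bias, so masked weights vanish after softmax
attention_bias = (1. - attention_mask) * -1e9
print(attention_bias)
```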
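A note on the topk control in section 4: the following is a hedged sketch of top-k selection as described there, assuming `logits` is the model's output at the last position. The helper name `pick_next_token` and the softmax renormalization over the k candidates are illustrative choices, not the notebook's exact code; with `topk=1` it reduces to greedy decoding.

```python
import numpy as np

def pick_next_token(logits: np.ndarray, topk: int = 3) -> int:
    """Pick the next token id at random from the top-k highest-scoring candidates."""
    top_ids = np.argsort(logits)[-topk:]   # ids of the k best candidates
    probs = np.exp(logits[top_ids])
    probs /= probs.sum()                   # softmax over the k candidates only
    return int(np.random.choice(top_ids, p=probs))

# Stand-in for vocabulary-sized logits; larger topk yields more varied poems,
# but a very large topk behaves almost like uniform random generation.
demo_logits = np.random.randn(5000)
print(pick_next_token(demo_logits, topk=5))
```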