当前位置:   article > 正文

textRNN & textCNN(及代码实现)

textrnn

1. 什么是textRNN

textRNN指的是利用RNN循环神经网络解决文本分类问题,文本分类是自然语言处理的一个基本任务,试图推断出给定文本(句子、文档等)的标签或标签集合。

文本分类的应用非常广泛,如:

  • 垃圾邮件分类:2分类问题,判断邮件是否为垃圾邮件
  • 情感分析:2分类问题:判断文本情感是积极还是消极;多分类问题:判断文本情感属于{非常消极,消极,中立,积极,非常积极}中的哪一类。
  • 新闻主题分类:判断一段新闻属于哪个类别,如财经、体育、娱乐等。根据类别标签的数量,可以是2分类也可以是多分类。
  • 自动问答系统中的问句分类
  • 社区问答系统中的问题分类:多标签多分类(对一段文本进行多分类,该文本可能有多个标签),如知乎看山杯
  • 让AI做法官:基于案件事实描述文本的罚金等级分类(多分类)和法条分类(多标签多分类)
  • 判断新闻是否为机器人所写:2分类

1.1 textRNN的原理

在一些自然语言处理任务中,当对序列进行处理时,我们一般会采用循环神经网络RNN,尤其是它的一些变种,如LSTM(更常用),GRU。当然我们也可以把RNN运用到文本分类任务中。

这里的文本可以一个句子,文档(短文本,若干句子)或篇章(长文本),因此每段文本的长度都不尽相同。在对文本进行分类时,我们一般会指定一个固定的输入序列/文本长度:该长度可以是最长文本/序列的长度,此时其他所有文本/序列都要进行填充以达到该长度;该长度也可以是训练集中所有文本/序列长度的均值,此时对于过长的文本/序列需要进行截断,过短的文本则进行填充。总之,要使得训练集中所有的文本/序列长度相同,该长度除之前提到的设置外,也可以是其他任意合理的数值。在测试时,也需要对测试集中的文本/序列做同样的处理。

首先我们需要对文本进行分词,然后指定一个序列长度n(大于n的截断,小于n的填充),并使用词嵌入得到每个词固定维度的向量表示。对于每一个输入文本/序列,我们可以在RNN的每一个时间步长上输入文本中一个单词的向量表示,计算当前时间步长上的隐藏状态,然后用于当前时间步骤的输出以及传递给下一个时间步长并和下一个单词的词向量一起作为RNN单元输入,然后再计算下一个时间步长上RNN的隐藏状态,以此重复...直到处理完输入文本中的每一个单词,由于输入文本的长度为n,所以要经历n个时间步长。

基于RNN的文本分类模型非常灵活,有多种多样的结构。接下来,我们主要介绍两种典型的结构。

2. textRNN网络结构

2.1 structure 1

流程:embedding--->BiLSTM--->concat final output/average all output----->softmax layer

结构图如下图所示:

一般取前向/反向LSTM在最后一个时间步长上隐藏状态,然后进行拼接,在经过一个softmax层(输出层使用softmax激活函数)进行一个多分类;或者取前向/反向LSTM在每一个时间步长上的隐藏状态,对每一个时间步长上的两个隐藏状态进行拼接,然后对所有时间步长上拼接后的隐藏状态取均值,再经过一个softmax层(输出层使用softmax激活函数)进行一个多分类(2分类的话使用sigmoid激活函数)。

上述结构也可以添加dropout/L2正则化或BatchNormalization 来防止过拟合以及加速模型训练。

2.2 structure 2

流程:embedding-->BiLSTM---->(dropout)-->concat ouput--->UniLSTM--->(droput)-->softmax layer

结构图如下图所示:

 

与之前结构不同的是,在双向LSTM(上图不太准确,底层应该是一个双向LSTM)的基础上又堆叠了一个单向的LSTM。把双向LSTM在每一个时间步长上的两个隐藏状态进行拼接,作为上层单向LSTM每一个时间步长上的一个输入,最后取上层单向LSTM最后一个时间步长上的隐藏状态,再经过一个softmax层(输出层使用softamx激活函数,2分类的话则使用sigmoid)进行一个多分类。

2.3 总结

TextRNN的结构非常灵活,可以任意改变。比如把LSTM单元替换为GRU单元,把双向改为单向,添加dropout或BatchNormalization以及再多堆叠一层等等。TextRNN在文本分类任务上的效果非常好,与TextCNN不相上下,但RNN的训练速度相对偏慢,一般2层就已经足够多了。

3. 什么是textCNN

在“卷积神经⽹络”中我们探究了如何使⽤⼆维卷积神经⽹络来处理⼆维图像数据。在之前的语⾔模型和⽂本分类任务中,我们将⽂本数据看作是只有⼀个维度的时间序列,并很⾃然地使⽤循环神经⽹络来表征这样的数据。其实,我们也可以将⽂本当作⼀维图像,从而可以⽤⼀维卷积神经⽹络来捕捉临近词之间的关联。本节将介绍将卷积神经⽹络应⽤到⽂本分析的开创性⼯作之⼀:textCNN

3.1 ⼀维卷积层

在介绍模型前我们先来解释⼀维卷积层的⼯作原理。与⼆维卷积层⼀样,⼀维卷积层使⽤⼀维的互相关运算。在⼀维互相关运算中,卷积窗口从输⼊数组的最左⽅开始,按从左往右的顺序,依次在输⼊数组上滑动。当卷积窗口滑动到某⼀位置时,窗口中的输⼊⼦数组与核数组按元素相乘并求和,得到输出数组中相应位置的元素。如下图所⽰,输⼊是⼀个宽为7的⼀维数组,核数组的宽为2。可以看到输出的宽度为 7 - 2 + 1 = 6,且第⼀个元素是由输⼊的最左边的宽为2的⼦数组与核数组按元素相乘后再相加得到的:0 × 1 + 1 × 2 = 2。

 

多输⼊通道的⼀维互相关运算也与多输⼊通道的⼆维互相关运算类似:在每个通道上,将核与相应的输⼊做⼀维互相关运算,并将通道之间的结果相加得到输出结果。下图展⽰了含3个输⼊ 通道的⼀维互相关运算,其中阴影部分为第⼀个输出元素及其计算所使⽤的输⼊和核数组元素: 0 × 1 + 1 × 2 + 1 × 3 + 2 × 4 + 2 × (-1) + 3 × (-3) = 2。

 

由⼆维互相关运算的定义可知,多输⼊通道的⼀维互相关运算可以看作单输⼊通道的⼆维互相关运算。如下图所⽰,我们也可以将上图中多输⼊通道的⼀维互相关运算以等价的单输⼊通道的⼆维互相关运算呈现。这⾥核的⾼等于输⼊的⾼。下图的阴影部分为第⼀个输出元素及其计算所使⽤的输⼊和核数组元素:2 × (-1) + 3 × (-3) + 1 × 3 + 2 × 4 + 0 × 1 + 1 × 2 = 2。

 

以上都是输出都只有⼀个通道。我们在“多输⼊通道和多输出通道”⼀节中介绍了如何在⼆维卷积层中指定多个输出通道。类似地,我们也可以在⼀维卷积层指定多个输出通道,从而拓展卷积层中的模型参数。

3. 2 时序最⼤池化层

类似地,我们有⼀维池化层。textCNN中使⽤的时序最⼤池化(max-over-time pooling)层实际上对应⼀维全局最⼤池化层:假设输⼊包含多个通道,各通道由不同时间步上的数值组成,各通道的输出即该通道所有时间步中最⼤的数值。因此,时序最⼤池化层的输⼊在各个通道上的时间步数可以不同。为提升计算性能,我们常常将不同⻓度的时序样本组成⼀个小批量,并通过在较短序列后附加特殊字符(如0)令批量中各时序样本⻓度相同。这些⼈为添加的特殊字符当然是⽆意义的。由于时序最⼤池化的主要⽬的是抓取时序中最重要的特征,它通常能使模型不受⼈为添加字符的影响。

3.3 textCNN模型

textCNN模型主要使⽤了⼀维卷积层和时序最⼤池化层。假设输⼊的⽂本序列由n个词组成,每个词⽤d维的词向量表⽰。那么输⼊样本的宽为n,⾼为1,输⼊通道数为d。textCNN的计算主要分为以下⼏步:

  1. 定义多个⼀维卷积核,并使⽤这些卷积核对输⼊分别做卷积计算。宽度不同的卷积核可能会捕捉到不同个数的相邻词的相关性。
  2. 对输出的所有通道分别做时序最⼤池化,再将这些通道的池化输出值连结为向量。
  3. 通过全连接层将连结后的向量变换为有关各类别的输出。这⼀步可以使⽤丢弃层应对过拟合。

下图⽤⼀个例⼦解释了textCNN的设计。这⾥的输⼊是⼀个有11个词的句⼦,每个词⽤6维词向量表⽰。因此输⼊序列的宽为11,输⼊通道数为6。给定2个⼀维卷积核,核宽分别为2和4,输出通道数分别设为4和5。因此,⼀维卷积计算后,4个输出通道的宽为 11 - 2 + 1 = 10,而其他5个通道的宽为 11 - 4 + 1 = 8。尽管每个通道的宽不同,我们依然可以对各个通道做时序最⼤池化,并将9个通道的池化输出连结成⼀个9维向量。最终,使⽤全连接将9维向量变换为2维输出,即正⾯情感和负⾯情感的预测。

 

4. 代码实现

4.1.textRNN实现新闻分类

数据集下载

使用THUCNews的一个子集进行训练与测试:https://www.lanzous.com/i5t0lsd

本次训练使用了其中的10个分类,每个分类6500条数据。类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

cnews_loader.py为数据的预处理文件。

  • read_file(): 读取文件数据;
  • build_vocab(): 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理;
  • read_vocab(): 读取上一步存储的词汇表,转换为{词:id}表示;
  • read_category(): 将分类目录固定,转换为{类别: id}表示;
  • to_words(): 将一条由id表示的数据重新转换为文字;
  • process_file(): 将数据集从文字转换为固定长度的id序列表示;
  • batch_iter(): 为神经网络的训练准备经过shuffle的批次的数据。

textRNN模型和可配置的参数,在rnn_model.py中。

  1. from __future__ import print_function
  2. import os
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. import numpy as np
  7. import tensorflow as tf
  8. from sklearn import metrics
  9. from rnn_model import TRNNConfig, TextRNN
  10. from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
  1. base_dir = 'cnews'
  2. train_dir = os.path.join(base_dir, 'cnews.train.txt')
  3. test_dir = os.path.join(base_dir, 'cnews.test.txt')
  4. val_dir = os.path.join(base_dir, 'cnews.val.txt')
  5. vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
  6. save_dir = 'checkpoints/textrnn'
  7. save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径
  8. def get_time_dif(start_time):
  9. """获取已使用时间"""
  10. end_time = time.time()
  11. time_dif = end_time - start_time
  12. return timedelta(seconds=int(round(time_dif)))
  13. def feed_data(x_batch, y_batch, keep_prob):
  14. feed_dict = {
  15. model.input_x: x_batch,
  16. model.input_y: y_batch,
  17. model.keep_prob: keep_prob
  18. }
  19. return feed_dict
  20. def evaluate(sess, x_, y_):
  21. """评估在某一数据上的准确率和损失"""
  22. data_len = len(x_)
  23. batch_eval = batch_iter(x_, y_, 128)
  24. total_loss = 0.0
  25. total_acc = 0.0
  26. for x_batch, y_batch in batch_eval:
  27. batch_len = len(x_batch)
  28. feed_dict = feed_data(x_batch, y_batch, 1.0)
  29. loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
  30. total_loss += loss * batch_len
  31. total_acc += acc * batch_len
  32. return total_loss / data_len, total_acc / data_len
  33. def train():
  34. print("Configuring TensorBoard and Saver...")
  35. # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖
  36. tensorboard_dir = 'tensorboard/textrnn'
  37. if not os.path.exists(tensorboard_dir):
  38. os.makedirs(tensorboard_dir)
  39. tf.summary.scalar("loss", model.loss)
  40. tf.summary.scalar("accuracy", model.acc)
  41. merged_summary = tf.summary.merge_all()
  42. writer = tf.summary.FileWriter(tensorboard_dir)
  43. # 配置 Saver
  44. saver = tf.train.Saver()
  45. if not os.path.exists(save_dir):
  46. os.makedirs(save_dir)
  47. print("Loading training and validation data...")
  48. # 载入训练集与验证集
  49. start_time = time.time()
  50. x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
  51. x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
  52. time_dif = get_time_dif(start_time)
  53. print("Time usage:", time_dif)
  54. # 创建session
  55. session = tf.Session()
  56. session.run(tf.global_variables_initializer())
  57. writer.add_graph(session.graph)
  58. print('Training and evaluating...')
  59. start_time = time.time()
  60. total_batch = 0 # 总批次
  61. best_acc_val = 0.0 # 最佳验证集准确率
  62. last_improved = 0 # 记录上一次提升批次
  63. require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练
  64. flag = False
  65. for epoch in range(config.num_epochs):
  66. print('Epoch:', epoch + 1)
  67. batch_train = batch_iter(x_train, y_train, config.batch_size)
  68. for x_batch, y_batch in batch_train:
  69. feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
  70. if total_batch % config.save_per_batch == 0:
  71. # 每多少轮次将训练结果写入tensorboard scalar
  72. s = session.run(merged_summary, feed_dict=feed_dict)
  73. writer.add_summary(s, total_batch)
  74. if total_batch % config.print_per_batch == 0:
  75. # 每多少轮次输出在训练集和验证集上的性能
  76. feed_dict[model.keep_prob] = 1.0
  77. loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
  78. loss_val, acc_val = evaluate(session, x_val, y_val) # todo
  79. if acc_val > best_acc_val:
  80. # 保存最好结果
  81. best_acc_val = acc_val
  82. last_improved = total_batch
  83. saver.save(sess=session, save_path=save_path)
  84. improved_str = '*'
  85. else:
  86. improved_str = ''
  87. time_dif = get_time_dif(start_time)
  88. msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
  89. + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
  90. print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))
  91. feed_dict[model.keep_prob] = config.dropout_keep_prob
  92. session.run(model.optim, feed_dict=feed_dict) # 运行优化
  93. total_batch += 1
  94. if total_batch - last_improved > require_improvement:
  95. # 验证集正确率长期不提升,提前结束训练
  96. print("No optimization for a long time, auto-stopping...")
  97. flag = True
  98. break # 跳出循环
  99. if flag: # 同上
  100. break
  101. def test():
  102. print("Loading test data...")
  103. start_time = time.time()
  104. x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)
  105. session = tf.Session()
  106. session.run(tf.global_variables_initializer())
  107. saver = tf.train.Saver()
  108. saver.restore(sess=session, save_path=save_path) # 读取保存的模型
  109. print('Testing...')
  110. loss_test, acc_test = evaluate(session, x_test, y_test)
  111. msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
  112. print(msg.format(loss_test, acc_test))
  113. batch_size = 128
  114. data_len = len(x_test)
  115. num_batch = int((data_len - 1) / batch_size) + 1
  116. y_test_cls = np.argmax(y_test, 1)
  117. y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果
  118. for i in range(num_batch): # 逐批次处理
  119. start_id = i * batch_size
  120. end_id = min((i + 1) * batch_size, data_len)
  121. feed_dict = {
  122. model.input_x: x_test[start_id:end_id],
  123. model.keep_prob: 1.0
  124. }
  125. y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
  126. # 评估
  127. print("Precision, Recall and F1-Score...")
  128. print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
  129. # 混淆矩阵
  130. print("Confusion Matrix...")
  131. cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
  132. print(cm)
  133. time_dif = get_time_dif(start_time)
  134. print("Time usage:", time_dif)

模型训练

  1. type_ = 'train'
  2. print('Configuring RNN model...')
  3. config = TRNNConfig()
  4. if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建
  5. build_vocab(train_dir, vocab_dir, config.vocab_size)
  6. categories, cat_to_id = read_category()
  7. words, word_to_id = read_vocab(vocab_dir)
  8. config.vocab_size = len(words)
  9. model = TextRNN(config)
  10. if type_ == 'train':
  11. train()
  12. else:
  13. test()

模型测试

test()

4.2.textCNN实现新闻分类

数据集下载

使用THUCNews的一个子集进行训练与测试:https://www.lanzous.com/i5t0lsd

本次训练使用了其中的10个分类,每个分类6500条数据。类别如下:

体育, 财经, 房产, 家居, 教育, 科技, 时尚, 时政, 游戏, 娱乐

cnews_loader.py为数据的预处理文件。

  • read_file(): 读取文件数据;
  • build_vocab(): 构建词汇表,使用字符级的表示,这一函数会将词汇表存储下来,避免每一次重复处理;
  • read_vocab(): 读取上一步存储的词汇表,转换为{词:id}表示;
  • read_category(): 将分类目录固定,转换为{类别: id}表示;
  • to_words(): 将一条由id表示的数据重新转换为文字;
  • process_file(): 将数据集从文字转换为固定长度的id序列表示;
  • batch_iter(): 为神经网络的训练准备经过shuffle的批次的数据。

textCNN的模型和可配置的参数,在cnn_model.py中。

  1. from __future__ import print_function
  2. import os
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. import numpy as np
  7. import tensorflow as tf
  8. from sklearn import metrics
  9. from cnn_model import TCNNConfig, TextCNN
  10. from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
  1. base_dir = 'cnews'
  2. train_dir = os.path.join(base_dir, 'cnews.train.txt')
  3. test_dir = os.path.join(base_dir, 'cnews.test.txt')
  4. val_dir = os.path.join(base_dir, 'cnews.val.txt')
  5. vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
  6. save_dir = 'checkpoints/textcnn'
  7. save_path = os.path.join(save_dir, 'best_validation') # 最佳验证结果保存路径
  8. def get_time_dif(start_time):
  9. """获取已使用时间"""
  10. end_time = time.time()
  11. time_dif = end_time - start_time
  12. return timedelta(seconds=int(round(time_dif)))
  13. def feed_data(x_batch, y_batch, keep_prob):
  14. feed_dict = {
  15. model.input_x: x_batch,
  16. model.input_y: y_batch,
  17. model.keep_prob: keep_prob
  18. }
  19. return feed_dict
  20. def evaluate(sess, x_, y_):
  21. """评估在某一数据上的准确率和损失"""
  22. data_len = len(x_)
  23. batch_eval = batch_iter(x_, y_, 128)
  24. total_loss = 0.0
  25. total_acc = 0.0
  26. for x_batch, y_batch in batch_eval:
  27. batch_len = len(x_batch)
  28. feed_dict = feed_data(x_batch, y_batch, 1.0)
  29. loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
  30. total_loss += loss * batch_len
  31. total_acc += acc * batch_len
  32. return total_loss / data_len, total_acc / data_len
  33. def train():
  34. print("Configuring TensorBoard and Saver...")
  35. # 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖
  36. tensorboard_dir = 'tensorboard/textcnn'
  37. if not os.path.exists(tensorboard_dir):
  38. os.makedirs(tensorboard_dir)
  39. tf.summary.scalar("loss", model.loss)
  40. tf.summary.scalar("accuracy", model.acc)
  41. merged_summary = tf.summary.merge_all()
  42. writer = tf.summary.FileWriter(tensorboard_dir)
  43. # 配置 Saver
  44. saver = tf.train.Saver()
  45. if not os.path.exists(save_dir):
  46. os.makedirs(save_dir)
  47. print("Loading training and validation data...")
  48. # 载入训练集与验证集
  49. start_time = time.time()
  50. x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
  51. x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
  52. time_dif = get_time_dif(start_time)
  53. print("Time usage:", time_dif)
  54. # 创建session
  55. session = tf.Session()
  56. session.run(tf.global_variables_initializer())
  57. writer.add_graph(session.graph)
  58. print('Training and evaluating...')
  59. start_time = time.time()
  60. total_batch = 0 # 总批次
  61. best_acc_val = 0.0 # 最佳验证集准确率
  62. last_improved = 0 # 记录上一次提升批次
  63. require_improvement = 1000 # 如果超过1000轮未提升,提前结束训练
  64. flag = False
  65. for epoch in range(config.num_epochs):
  66. print('Epoch:', epoch + 1)
  67. batch_train = batch_iter(x_train, y_train, config.batch_size)
  68. for x_batch, y_batch in batch_train:
  69. feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
  70. if total_batch % config.save_per_batch == 0:
  71. # 每多少轮次将训练结果写入tensorboard scalar
  72. s = session.run(merged_summary, feed_dict=feed_dict)
  73. writer.add_summary(s, total_batch)
  74. if total_batch % config.print_per_batch == 0:
  75. # 每多少轮次输出在训练集和验证集上的性能
  76. feed_dict[model.keep_prob] = 1.0
  77. loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
  78. loss_val, acc_val = evaluate(session, x_val, y_val) # todo
  79. if acc_val > best_acc_val:
  80. # 保存最好结果
  81. best_acc_val = acc_val
  82. last_improved = total_batch
  83. saver.save(sess=session, save_path=save_path)
  84. improved_str = '*'
  85. else:
  86. improved_str = ''
  87. time_dif = get_time_dif(start_time)
  88. msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
  89. + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
  90. print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))
  91. feed_dict[model.keep_prob] = config.dropout_keep_prob
  92. session.run(model.optim, feed_dict=feed_dict) # 运行优化
  93. total_batch += 1
  94. if total_batch - last_improved > require_improvement:
  95. # 验证集正确率长期不提升,提前结束训练
  96. print("No optimization for a long time, auto-stopping...")
  97. flag = True
  98. break # 跳出循环
  99. if flag: # 同上
  100. break
  101. def test():
  102. print("Loading test data...")
  103. start_time = time.time()
  104. x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)
  105. session = tf.Session()
  106. session.run(tf.global_variables_initializer())
  107. saver = tf.train.Saver()
  108. saver.restore(sess=session, save_path=save_path) # 读取保存的模型
  109. print('Testing...')
  110. loss_test, acc_test = evaluate(session, x_test, y_test)
  111. msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
  112. print(msg.format(loss_test, acc_test))
  113. batch_size = 128
  114. data_len = len(x_test)
  115. num_batch = int((data_len - 1) / batch_size) + 1
  116. y_test_cls = np.argmax(y_test, 1)
  117. y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32) # 保存预测结果
  118. for i in range(num_batch): # 逐批次处理
  119. start_id = i * batch_size
  120. end_id = min((i + 1) * batch_size, data_len)
  121. feed_dict = {
  122. model.input_x: x_test[start_id:end_id],
  123. model.keep_prob: 1.0
  124. }
  125. y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
  126. # 评估
  127. print("Precision, Recall and F1-Score...")
  128. print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
  129. # 混淆矩阵
  130. print("Confusion Matrix...")
  131. cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
  132. print(cm)
  133. time_dif = get_time_dif(start_time)
  134. print("Time usage:", time_dif)

模型训练

sys.argv 输出运行文件名

  1. type_ = 'train'
  2. print('Configuring CNN model...')
  3. config = TCNNConfig()
  4. if not os.path.exists(vocab_dir): # 如果不存在词汇表,重建
  5. build_vocab(train_dir, vocab_dir, config.vocab_size)
  6. categories, cat_to_id = read_category()
  7. words, word_to_id = read_vocab(vocab_dir)
  8. config.vocab_size = len(words)
  9. model = TextCNN(config)
  10. if type_ == 'train':
  11. train()
  12. else:
  13. test()

模型测试

test()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/729943
推荐阅读
相关标签
  

闽ICP备14008679号