当前位置:   article > 正文

tensor 增加一维度_PyTorch搭建聊天机器人(一)词表与数据加载器

使用pycharm开发基于seq2seq模型的简易中文聊天机器人

ecb17875494938654a0f8f3df6508aef.png

国庆无聊逛了逛PyTorch的tutorial,其中有一篇chatbot的搭建蛮有意思的。

Chatbot Tutorial​pytorch.org
de870955910367abdf2c5d14625f46a7.png

作为一个蒟蒻大二,我看了看tutorial涉及到的论文,并且自己按照batch_first=True动手写了写,算是有点收获吧。打算写三四篇文章总结一下技术细节。顺序大概是:数据加载器、网络前向逻辑、训练逻辑、评估逻辑。

使用pycharm编写项目,代码分为四个文件:process.py、neural_network.py、train.py、evaluate.py。

先大致说一下搭建chatbot的思路吧,其实很简单:这里的chatbot是基于带Luong attention机制的seq2seq。研究过NLP的同学应该对seq2seq很熟悉,它可以将任意长度的时序信息映射到任意长度,在基于深度神经网络的机器翻译中使用广泛。

实际上,中文翻译成英文就是训练出一个中文序列到英文序列的映射,而我们的chatbot不就是一个句子到句子的映射吗?在不考虑上下语境的情况下,聊天机器人可以使用seq2seq搭建。如此搭建的聊天机器人对用户输入的语句给出的回复更像是将用户说的话翻译成了用户希望得到的回复。那么假设我们已经对seq2seq很熟悉了,那么只需要使用一条条对话(下面叫dialog或者pair)作为数据,训练这个seq2seq模型就可以得到这个训练集风格的chatbot了。

tutorial使用的数据是Cornell Movie-Dialogs,下载地址。这部分数据的编码格式不是utf-8,如果你对编码转换这部分不感兴趣,可以直接使用笔者仓库中./data中的tsv数据。后面的程序中将会直接使用tsv数据。

笔者仓库链接如下:

LSTM-Kirigaya/chatbot-based-on-seq2seq2​github.com
5cb7995712161dfdfae56677adf32e0e.png

提前说明一下,对话数据集中的每个pair中,我们把第一句话成为input_dialog,后面一句回复的话称为ouput_dialog。

5468ecd619c7bb5add2e3c3af7ad3f14.png

下面完成process.py,这个文件完成词表建立和数据加载器的建立。

说明:下面所有数据组织都是按照batch_first来的,也就是所有torch张量的第一个维度是batch_size

先引入需要的库

  1. from itertools import zip_longest
  2. import random
  3. import torch

构建词表

第一步我们需要构建词表,因为网络中只会传递张量,我们需要通过构建词表将每个单词映射成一个个单词索引(后面成为index),也就是将一句话转化为index序列。

词表中最核心的数据是三个python类型的词典:

  • word2index:单词到其对应的index的映射。
  • index2word:index到其对应的单词的映射。
  • word2count:单词到其在数据集中的总数的映射。

构建词表的逻辑也很简单,只需要遍历数据集,每遇到一个词表中没有的单词,就根据已经添加单词的总数给与这个新的单词一个index,并由此给word2index和index2word两个字典增加新的元素。

程序如下:

  1. # 用来构造字典的类
  2. class vocab(object):
  3. def __init__(self, name, pad_token, sos_token, eos_token, unk_token):
  4. self.name = name
  5. self.pad_token = pad_token
  6. self.sos_token = sos_token
  7. self.eos_token = eos_token
  8. self.unk_token = unk_token
  9. self.trimmed = False # 代表这个词表对象是否经过了剪枝操作
  10. self.word2index = {"PAD" : pad_token, "SOS" : sos_token, "EOS" : eos_token, "UNK" : unk_token}
  11. self.word2count = {"UNK" : 0}
  12. self.index2word = {pad_token : "PAD", sos_token : "SOS", eos_token : "EOS", unk_token : "UNK"}
  13. self.num_words = 4 # 刚开始的四个占位符 pad(0), sos(1), eos(2),unk(3) 代表目前遇到的不同的单词数量
  14. # 向voc中添加一个单词的逻辑
  15. def addWord(self, word):
  16. if word not in self.word2index:
  17. self.word2index[word] = self.num_words
  18. self.word2count[word] = 1
  19. self.index2word[self.num_words] = word
  20. self.num_words += 1
  21. else:
  22. self.word2count[word] += 1
  23. # 向voc中添加一个句子的逻辑
  24. def addSentence(self, sentence):
  25. for word in sentence.split():
  26. self.addWord(word)
  27. # 将词典中词频过低的单词替换为unk_token
  28. # 需要一个代表修剪阈值的参数min_count,词频低于这个参数的单词会被替换为unk_token,相应的词典变量也会做出相应的改变
  29. def trim(self, min_count):
  30. if self.trimmed: # 如果已经裁剪过了,那就直接返回
  31. return
  32. self.trimmed = True
  33. keep_words = []
  34. keep_num = 0
  35. for word, count in self.word2count.items():
  36. if count >= min_count:
  37. keep_num += 1
  38. # 由于后面是通过对keep_word列表中的数据逐一统计,所以需要对count>1的单词重复填入
  39. for _ in range(count):
  40. keep_words.append(word)
  41. print("keep words: {} / {} = {:.4f}".format(
  42. keep_num, self.num_words - 4, keep_num / (self.num_words - 4)
  43. ))
  44. # 重构词表
  45. self.word2index = {"PAD" : self.pad_token, "SOS" : self.sos_token, "EOS" : self.eos_token, "UNK" : self.unk_token}
  46. self.word2count = {}
  47. self.index2word = {self.pad_token : "PAD", self.sos_token : "SOS", self.eos_token : "EOS", self.unk_token : "UNK"}
  48. self.num_words = 4
  49. for word in keep_words:
  50. self.addWord(word)
  51. # 读入数据,统计词频,并返回数据
  52. def load_data(self, path):
  53. pairs = []
  54. for line in open(path, "r", encoding="utf-8"):
  55. try:
  56. input_dialog, output_dialog = line.strip().split("t")
  57. self.addSentence(input_dialog.strip())
  58. self.addSentence(output_dialog.strip())
  59. pairs.append([input_dialog, output_dialog])
  60. except:
  61. pass
  62. return pairs

这个词表类,需要五个参数初始化:name、pad_token、sos_token、eos_token、unk_token。分别为词表的名称、填充词的index、句子开头标识符的index、句子结束标识符的index和未识别单词的index。

主要方法说明如下:

  • __init__:完成词表的初始化。
  • trim:根据min_count对词表进行剪枝。
  • load_data:载入外部tsv数据,完成三个字典的搭建,并返回处理好的pairs。

处理input_dialog和output_dialog

有了词表,我们就可以根据词表把一句话转换成index序列,为此我们通过sentenceToIndex函数完成sentence到index sequence的转换,需要说明的是,为了让后续搭建的网络知道一句话已经结束了,我们需要给每个转换成的index序列的句子添加一个eos_token作为后缀:

  1. # 将一句话转换成id序列(str->list),结尾加上EOS
  2. def sentenceToIndex(sentence, voc):
  3. return [voc.word2index[word] for word in sentence.split()] + [voc.eos_token]

接下来我们需要分别处理input_dialog和output_dialog。

处理input_dialog需要一个batchInput2paddedTensor函数,这个函数接受batch_size句没有处理过的input_dialog文字、将它们转换为index序列、填充pad_token、转换成batch_first=True的torch张量,返回处理好的torch张量和每句话的长度信息。

大致过程还是画张图吧。。。

891a8e511836a7daeb19f69c81fba27d.png

代码如下:

  1. # 将一个batch中的input_dialog转化为有pad填充的tensor,并返回tensor和记录长度的变量
  2. # 返回的tensor是batch_first
  3. def batchInput2paddedTensor(batch, voc):
  4. # 先转换为id序列,但是这个id序列不对齐
  5. batch_index_seqs = [sentenceToIndex(sentence, voc) for sentence in batch]
  6. length_tensor = torch.tensor([len(index_seq) for index_seq in batch_index_seqs])
  7. # 下面填充0(PAD),使得这个batch中的序列对齐
  8. zipped_list = list(zip_longest(*batch_index_seqs, fillvalue=voc.pad_token))
  9. padded_tensor = torch.tensor(zipped_list).t()
  10. return padded_tensor, length_tensor

处理output_dialog与input_dialog差不多,只不过需要多返回一个mask矩阵,所谓mask矩阵,就是将padded_tensor转换成bool类型。这些返回的量在后续的训练中都会使用到。output_dialog的处理如下:

  1. # 将一个batch中的output_dialog转化为有pad填充的tensor,并返回tensor、mask和最大句长
  2. # 返回的tensor是batch_first
  3. def batchOutput2paddedTensor(batch, voc):
  4. # 先转换为id序列,但是这个id序列不对齐
  5. batch_index_seqs = [sentenceToIndex(sentence, voc) for sentence in batch]
  6. max_length = max([len(index_seq) for index_seq in batch_index_seqs])
  7. # 下面填充0(PAD),使得这个batch中的序列对齐
  8. zipped_list = list(zip_longest(*batch_index_seqs, fillvalue=voc.pad_token))
  9. padded_tensor = torch.tensor(zipped_list).t()
  10. # 得到padded_tensor对应的mask
  11. mask = torch.BoolTensor(zipped_list).t()
  12. return padded_tensor, mask, max_length

有了处理pair的函数,我们可以把上面的函数整合成一个数据加载器loader。数据加载器在深度学习中很重要,我们在训练中需要能够不重复的、快速地获取一个batch的格式化数据,这就是loader的功能,惬意舒适(in my dream...)的训练中,一个设计合理而高效的loader是必不可少的。

此处不再解释Python中生成器的概念,为了更加节省内存空间,我们将loader做成一个生成器:

  1. # 获取数据加载器的函数
  2. # 将输入的一个batch的dialog转换成id序列,填充pad,并返回训练可用的id张量和mask
  3. def DataLoader(pairs, voc, batch_size, shuffle=True):
  4. if shuffle:
  5. random.shuffle(pairs)
  6. batch = []
  7. for idx, pair in enumerate(pairs):
  8. batch.append([pair[0], pair[1]])
  9. # 数据数量到达batch_size就yield出去并清空
  10. if len(batch) == batch_size:
  11. # 为了后续的pack_padded_sequence操作,我们需要给这个batch中的数据按照input_dialog的长度排序(降序)
  12. batch.sort(key=lambda x : len(x[0].split()), reverse=True)
  13. input_dialog_batch = []
  14. output_dialog_batch = []
  15. for pair in batch:
  16. input_dialog_batch.append(pair[0])
  17. output_dialog_batch.append(pair[1])
  18. input_tensor, input_length_tensor = batchInput2paddedTensor(input_dialog_batch, voc)
  19. output_tensor, mask, max_length = batchOutput2paddedTensor(output_dialog_batch, voc)
  20. # 清空临时缓冲区
  21. batch = []
  22. yield [
  23. input_tensor, input_length_tensor, output_tensor, mask, max_length
  24. ]

要写的函数差不多写好了,我们可以测试一下:

  1. if __name__ == "__main__":
  2. PAD_token = 0 # 补足句长的pad占位符的index
  3. SOS_token = 1 # 代表一句话开头的占位符的index
  4. EOS_token = 2 # 代表一句话结尾的占位符的index
  5. UNK_token = 3 # 代表不在词典中的字符
  6. BATCH_SIZE = 64 # 一个batch中的对话数量(样本数量)
  7. MAX_LENGTH = 20 # 一个对话中每句话的最大句长
  8. MIN_COUNT = 3 # trim方法的修剪阈值
  9. # 实例化词表
  10. voc = vocab(name="corpus", pad_token=PAD_token, sos_token=SOS_token, eos_token=EOS_token, unk_token=UNK_token)
  11. # 为词表载入数据,统计词频,并得到对话数据
  12. pairs = voc.load_data(path="./data/dialog.tsv")
  13. print("total number of dialogs:", len(pairs))
  14. # 修剪与替换
  15. pairs = trimAndReplace(voc, pairs, MIN_COUNT)
  16. # 获取loader
  17. loader = DataLoader(pairs, voc, batch_size=5)
  18. batch_item_names = ["input_tensor", "input_length_tensor", "output_tensor", "mask", "max_length"]
  19. for batch_index, batch in enumerate(loader):
  20. for name, item in zip(batch_item_names, batch):
  21. print(f"n{name} : {item}")
  22. break

out:

  1. total number of dialogs: 64223
  2. keep words: 7821 / 17999 = 0.4345
  3. Trimmed from 64223 pairs to 58362, 0.9087 of total
  4. input_tensor : tensor([[ 123, 51, 48, 8, 918, 2227, 330, 3068, 7, 2],
  5. [ 302, 303, 102, 38, 3, 71, 3, 7, 2, 0],
  6. [ 158, 3, 7, 2, 0, 0, 0, 0, 0, 0],
  7. [ 188, 7, 2, 0, 0, 0, 0, 0, 0, 0],
  8. [ 563, 5, 2, 0, 0, 0, 0, 0, 0, 0]])
  9. input_length_tensor : tensor([10, 9, 4, 3, 3])
  10. output_tensor : tensor([[3244, 5, 2, 0, 0, 0, 0, 0, 0, 0],
  11. [ 35, 37, 38, 68, 77, 5, 2, 0, 0, 0],
  12. [ 181, 5, 1233, 13, 1233, 13, 1222, 5, 2, 0],
  13. [ 102, 38, 45, 188, 99, 680, 1375, 5, 2, 0],
  14. [ 26, 198, 118, 25, 51, 41, 48, 1597, 5, 2]])
  15. mask : tensor([[ True, True, True, False, False, False, False, False, False, False],
  16. [ True, True, True, True, True, True, True, False, False, False],
  17. [ True, True, True, True, True, True, True, True, True, False],
  18. [ True, True, True, True, True, True, True, True, True, False],
  19. [ True, True, True, True, True, True, True, True, True, True]])
  20. max_length : 10

做好了数据加载器,后面就可以开始构建网络结构了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/379222
推荐阅读
相关标签
  

闽ICP备14008679号