
[AI Notes] Part 1: A seq2seq chatbot implemented with Keras

Source code: https://github.com/tfwcn/AI

 

word.txt is the character set used by the model; the version used in this article contains only the characters that appear in the training set.

ai.txt is the training data. Format: question1\tanswer1\nquestion2\tanswer2\n, i.e. one question–answer pair per line, with a tab between the question and the answer.
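For example, a minimal ai.txt could be generated like this (the two question–answer pairs below are invented purely for illustration; what matters is the tab between question and answer and the newline between pairs):

with open('ai.txt', 'w', encoding='UTF-8') as f:
    f.write('你好\t你好,很高兴见到你\n')
    f.write('你是谁\t我是一个聊天机器人\n')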

 

LSTM schematic (the original figure is not reproduced here): σ denotes the sigmoid function.
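Since the figure is missing here, for reference the standard LSTM cell computes the following at each timestep, where [h_{t-1}, x_t] denotes concatenation and ⊙ is element-wise multiplication:

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$

Here f_t, i_t and o_t are the forget, input and output gates, C_t is the cell state, and h_t is the hidden state (the layer's output).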

Training process diagram (the original figure is not reproduced here).

 

The code is as follows:

import keras as K
import matplotlib.pyplot as plt
import numpy as np
import math
import os

batch_size = 50  # Batch size for training.
epochs = 1000  # Number of epochs to train for.
latent_dim = 128  # Latent dimensionality of the encoding space (hidden units per LSTM).
num_samples = 10000  # Number of samples to train on.
max_encoder_seq_length = 256  # Maximum sentence length.

word_file = open('word.txt', 'r', encoding='UTF-8')
alphabet = word_file.read()  # 2500
# alphabet += 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'  # letters and digits
# alphabet += ',./;\'[]\\-=`<>?:"{+}|_)(*&^%$#@!~` '  # punctuation
# alphabet += ',。《》?;‘’:“”【】—()…¥!·'  # Chinese punctuation
# alphabet += '\t\n'  # start / end markers
word_file.close()
print('word', len(alphabet), alphabet)
# Training data set
train_file = open('ai.txt', 'r', encoding='UTF-8')
sentences = train_file.read().split('\n')
train_file.close()
question_texts = []
answer_texts = []
for sentence in sentences:
    if len(sentence) == 0:
        continue
    # Add any characters missing from word.txt (rerun after the file is updated)
    for t, char in enumerate(sentence):
        if alphabet.find(char) == -1:
            f2 = open('word.txt', 'w', encoding='utf-8')
            f2.truncate()  # clear the file
            alphabet += char
            f2.write(alphabet)
            f2.close()
    print('sentence', sentence.split('\t'))
    question_text, answer_text = sentence.split('\t')
    # \t marks the start of a sequence
    # \n marks the end of a sequence
    question_text = '\t' + question_text + '\n'
    answer_text = '\t' + answer_text + '\n'
    # question_text = question_text.ljust(max_encoder_seq_length, '\0')
    # answer_text = answer_text.ljust(max_encoder_seq_length, '\0')
    question_texts.append(question_text)
    answer_texts.append(answer_text)
# print('question_texts', question_texts)
# print('answer_texts', answer_texts)
# Dictionaries mapping characters to indices and back
char_to_int = dict((c, i) for i, c in enumerate(alphabet))
int_to_char = dict((i, c) for i, c in enumerate(alphabet))
# print('char_to_int', char_to_int)
# print('int_to_char', int_to_char)
# Number of encoder tokens
num_encoder_tokens = len(alphabet)
# Number of decoder tokens
num_decoder_tokens = len(alphabet)
# Number of samples
print('Number of samples:', len(question_texts))
# Encoder input
encoder_input_data = np.zeros(
    (len(question_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
# Decoder input
decoder_input_data = np.zeros(
    (len(question_texts), max_encoder_seq_length, num_decoder_tokens),
    dtype='float32')
# Decoder output, shifted one timestep ahead
decoder_target_data = np.zeros(
    (len(question_texts), max_encoder_seq_length, num_decoder_tokens),
    dtype='float32')
# enumerate yields index and element; zip pairs the two lists into tuples
# The loop below converts the training data to one-hot vectors
for i, (input_text, target_text) in enumerate(zip(question_texts, answer_texts)):
    # print('input_text', input_text)
    # print('target_text', target_text)
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, char_to_int[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, char_to_int[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, char_to_int[char]] = 1.
print('encoder_input_data', len(encoder_input_data))
print('decoder_input_data', len(decoder_input_data))
# ================== Encoder =====================
# Define an input sequence and process it.
# Input: one sentence, one-hot encoded
encoder_inputs = K.Input(shape=(None, num_encoder_tokens))
# return_state=True returns the internal states so they can be passed to the decoder
encoder = K.layers.LSTM(latent_dim, return_sequences=True,
                        return_state=True, activation=K.activations.tanh)
encoder2 = K.layers.LSTM(latent_dim, return_sequences=False,
                         return_state=True, activation=K.activations.tanh)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
encoder_outputs2, state_h2, state_c2 = encoder2(encoder_outputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]
encoder_states2 = [state_h2, state_c2]
# ================== Encoder end =====================
# ================== Decoder =====================
# Set up the decoder, using `encoder_states` as initial state.
# During training, the ground-truth answer is fed in as the decoder input (teacher forcing)
decoder_inputs = K.Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
# return_sequences=True returns the full output sequence
decoder_lstm = K.layers.LSTM(
    latent_dim, return_sequences=True, return_state=True, activation=K.activations.tanh)
decoder_lstm2 = K.layers.LSTM(
    latent_dim, return_sequences=True, return_state=True, activation=K.activations.tanh)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_outputs2, _, _ = decoder_lstm2(decoder_outputs,
                                       initial_state=encoder_states2)
decoder_dense = K.layers.Dense(
    num_decoder_tokens, activation=K.activations.softmax)
# Output: the predicted answer
decoder_outputs = decoder_dense(decoder_outputs2)
# ================== Decoder end =====================
# Encoder               Decoder
# \t h i \n             \t 你 好 \n
# LSTM         ->       LSTM
#                        你 好 \n
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = K.Model([encoder_inputs, decoder_inputs], decoder_outputs)
if os.path.exists('s2s.h5'):
    print('Loading model')
    model.load_weights('s2s.h5')
# Run training
# encoder_input_data: the input question
# decoder_input_data: the decoder input, starting with \t
# decoder_target_data: the true answer, shifted one step ahead
model.compile(K.optimizers.RMSprop(),
              loss=[K.losses.categorical_crossentropy],
              metrics=[K.metrics.categorical_crossentropy])
# model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
#           batch_size=batch_size,
#           epochs=epochs,
#           validation_split=0.2)
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs)
# Save model
model.save_weights('s2s.h5')
# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
#    and a "start of sequence" token as target.
#    Output will be the next target token
# 3) Repeat with the current target token and current states
# Define sampling models
# Encoder model: outputs the encoder states
encoder_model = K.Model(encoder_inputs, encoder_states + encoder_states2)
# Decoder model
# State inputs
decoder_state_input_h = K.Input(shape=(latent_dim,))
decoder_state_input_c = K.Input(shape=(latent_dim,))
decoder_state_input_h2 = K.Input(shape=(latent_dim,))
decoder_state_input_c2 = K.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_states_inputs2 = [decoder_state_input_h2, decoder_state_input_c2]
# Reuse the trained LSTM layers
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs2, state_h2, state_c2 = decoder_lstm2(
    decoder_outputs, initial_state=decoder_states_inputs2)
decoder_states2 = [state_h2, state_c2]
decoder_outputs = decoder_dense(decoder_outputs2)
# Inputs:  [decoder_inputs] + decoder state inputs
# Outputs: [decoder_outputs] + decoder states
decoder_model = K.Model(
    [decoder_inputs] + decoder_states_inputs + decoder_states_inputs2,
    [decoder_outputs] + decoder_states + decoder_states2)
def decode_sequence(input_seq):
    # Encode the input as state vectors (an abstract representation of the question).
    states_value = encoder_model.predict(input_seq)
    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, char_to_int['\t']] = 1.
    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c, h2, c2 = decoder_model.predict(
            [target_seq] + states_value)
        # Sample a token and append the predicted character to the output string
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = int_to_char[sampled_token_index]
        decoded_sentence += sampled_char
        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
                len(decoded_sentence) > max_encoder_seq_length):
            stop_condition = True
        # Update the target sequence (of length 1):
        # the current character is fed into the next prediction step
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.
        # Update states: carried over to the next prediction step
        states_value = [h, c, h2, c2]
    return decoded_sentence
for seq_index in range(10):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', question_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)

Run it with:

python seq2seq.py
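To chat with the trained model on a new question (rather than just replaying training samples), you can one-hot encode the question the same way as the training data and pass it to decode_sequence. The ask helper below is not part of the original script, only a minimal sketch; it simply skips characters that are missing from word.txt:

def ask(question):
    # Wrap the question with the same start/end markers used for training
    question = '\t' + question + '\n'
    input_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(question):
        if char in char_to_int:  # skip characters not in word.txt
            input_seq[0, t, char_to_int[char]] = 1.
    return decode_sequence(input_seq)

print(ask('你好'))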

 

References:

https://www.jianshu.com/p/9dc9f41f0b29

 
