When a chatbot is built with an LSTM, the model produces an output for any input, even input that is completely ungrammatical. That output is often not what we want, so how do we recognize whether the model's output is acceptable? We need an evaluation metric that measures the confidence of the model's output. Evaluating the confidence of an LSTM model essentially amounts to judging whether the input and the model's output belong to the training corpus, because the LSTM has learned a mapping between the corpus inputs and their labels; for inputs outside the training corpus, the model's output is essentially random. The model's confidence can therefore be estimated by checking whether the input and the LSTM output fall within the training corpus.
Segment each question in the training corpus into words key1, key2, ..., keyn and use these words as dictionary keys; the corresponding answer list [answer1, answer2, ..., answerk] is the dictionary value, and answers whose questions share the same keyword are appended to the same list, as in the following table:

keyword | answer list
key1    | [answer1, answer2, ...]
key2    | [answer2, ...]
...     | ...
keyn    | [..., answerk]
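As a concrete illustration, here is a minimal sketch of how such a keyword-to-answer dictionary can be built with jieba word segmentation. The helper name build_relevance_dict and the toy question/answer pairs are made up for illustration only; the full implementation used in this article appears further below.

import jieba

def build_relevance_dict(qa_pairs):
    # Map each segmented keyword of a question to the list of answers
    # whose questions contain that keyword.
    rel_dict = {}
    for ask, answer in qa_pairs:
        for key in set(jieba.lcut(ask)):      # segment the question into keywords
            rel_dict.setdefault(key, []).append(answer)
    return rel_dict

# Toy corpus, for illustration only:
qa_pairs = [("今天天气怎么样", "今天天气很好"),
            ("天气预报说明天下雨吗", "明天有小雨")]
rel_dict = build_relevance_dict(qa_pairs)
print(rel_dict)   # each keyword maps to its list of candidate answers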
When the LSTM model is called to make a prediction, segment the input in the same way to obtain key1, key2, ..., keyn, look each keyword up in the relevance dictionary to get its answer list, and compare the answers in that list with the LSTM model's output; every hit increments a counter count. An empty (missing) answer list counts as a miss. The confidence is then computed as:
confidence = count / key_size (where key_size is the number of distinct keywords in the input)
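A minimal sketch of this lookup step, reusing the hypothetical build_relevance_dict and qa_pairs from the sketch above (the article's actual implementation, RelevanceChat.relevance below, additionally discounts high-frequency words):

def chat_confidence(rel_dict, ask, model_answer):
    # Fraction of the input's keywords whose answer list contains the model output.
    keys = set(jieba.lcut(ask))
    if not keys:
        return 0.0
    count = sum(1 for key in keys if model_answer in rel_dict.get(key, []))
    return count / len(keys)

print(chat_confidence(rel_dict, "今天天气怎么样", "今天天气很好"))   # 1.0: the pair is in the corpus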
import jieba
import pickle


class RelevanceChat():
    def __init__(self, topk=5):
        self.topk = topk
        self.fited = False

    def fit(self, x_data, y_data, ws_decode):
        """Build the keyword -> answer-code-list dictionary from the training corpus."""
        self.dict = {}
        # A keyword appearing in more than 60% of the questions is treated as a
        # high-frequency word and its entry is replaced by the marker [0].
        high_fw_max = int(len(x_data) * 0.6)
        for ask, answer in zip(x_data, y_data):
            ask_str = ''.join(ask)
            if len(ask_str) == 0:
                continue
            top_key = jieba.lcut(ask_str)
            #print("top key:", top_key)
            y_code = ws_decode.transform(answer)[0]
            key_set = set(top_key)
            for key in key_set:
                rel_list = []
                if key in self.dict:
                    rel_list = self.dict[key]
                    if rel_list[0] == 0:
                        # Already marked as a high-frequency word, skip it.
                        continue
                    elif len(rel_list) >= high_fw_max:
                        print("key list over:", key, "ask:", ask_str)
                        self.dict[key] = [0]
                        continue
                rel_list.append(y_code)
                self.dict[key] = rel_list
        dict_items = self.dict.items()
        print("size:", len(self.dict))
        #print("dict:", dict_items)
        self.fited = True

    def relevance(self, ask, answer):
        """Return the confidence that (ask, answer) belongs to the training corpus."""
        assert self.fited, "RelevanceChat has not been fitted yet"
        top_key = jieba.lcut(''.join(ask))
        #print("top key:", top_key)
        key_set = set(top_key)
        key_size = len(key_set)
        if key_size == 0:
            return 0.0
        rel_num = 0   # keywords whose answer list contains the model output
        high_fw = 0   # keywords marked as high-frequency words
        for key in key_set:
            rel_list = self.dict.get(key)
            if rel_list is not None:
                if rel_list[0] == 0:
                    high_fw += 1
                elif answer in rel_list:
                    rel_num += 1
        if rel_num == 0:
            relv_val = float(high_fw) / key_size
        else:
            relv_val = float(rel_num) / (key_size - high_fw)
        return relv_val


def test():
    x_data, y_data = pickle.load(open('pkl/chatbot.pkl', 'rb'))
    ws_decode = pickle.load(open('pkl/ws_decode.pkl', 'rb'))
    relv = RelevanceChat(5)
    relv.fit(x_data, y_data, ws_decode)

    # In-corpus pairs: count dialogues whose confidence falls below 0.7.
    count = 0
    for ask, answer in zip(x_data, y_data):
        decode = ws_decode.transform(answer)[0]
        relv_val = relv.relevance(ask, decode)
        if relv_val < 0.7:
            print("rel:", relv_val)
            print("ask:", ''.join(ask))
            print("answer:", ''.join(answer), end='\n\n')
            count += 1
    print("same dialogue Confidence<0.7 count:", count)

    # Cross pairs: count mismatched question/answer pairs whose confidence exceeds 0.7.
    count = 0
    for i, answer in enumerate(y_data):
        decode = ws_decode.transform(answer)[0]
        for j, ask in enumerate(x_data):
            if i == j:
                continue
            relv_val = relv.relevance(ask, decode)
            if relv_val > 0.7:
                #print("rel:", relv_val)
                #print("ask:", ''.join(ask))
                #print("answer:", ''.join(answer), end='\n\n')
                count += 1
    print("different dialogue Confidence>0.7 count:", count)


if __name__ == '__main__':
    test()
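In an actual chat loop, the relevance score can be used to gate the LSTM output, for example as in the sketch below. The lstm_predict function, the 0.7 threshold, and the fallback reply are assumptions for illustration and are not part of the code above.

# Hypothetical wiring of RelevanceChat into a chat loop;
# lstm_predict stands in for the trained LSTM chatbot and is not defined here.
def answer_with_confidence(relv, lstm_predict, ask, threshold=0.7):
    answer = lstm_predict(ask)                  # model output, in the same encoding used in fit
    confidence = relv.relevance(ask, answer)    # score the pair against the training corpus
    if confidence < threshold:
        # Low confidence: the pair is unlikely to be covered by the training corpus.
        return "抱歉，我不太明白您的意思。", confidence
    return answer, confidence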
Testing with a corpus of 129 dialogues gives the following results: the number of in-corpus dialogues with confidence below 0.7 is 0, while across different dialogues 61 mismatched pairs have confidence above 0.7, giving a false positive rate of 61/(129*128) = 0.37%.