
CRF-Based Named Entity Recognition


  Named entity recognition has great practical value, for example identifying person names in a passage of text, and it serves as an important foundation for building knowledge graphs. The most common entity types are person names, place names, times, and organizations; entity types can also be defined according to business needs. This article takes the CRF model as its theoretical basis and uses the People's Daily corpus to recognize person names, place names, times, and organizations, so that the required entity information can be extracted from a long piece of text.

  The theory behind CRF can be studied in the references below; this article is mainly a set of study notes and a basic technical exercise to prepare for later business needs.
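  Before the pipeline itself, the snippet below is a minimal, self-contained sketch of the label scheme it produces: a corpus sentence of word/POS tokens is turned into character-level BIO labels. The sentence, the helper loop and the labels here are made up for illustration and are not read from the corpus; the full preprocessing and training pipeline follows.

# Illustrative only: word-level entity labels expanded to character-level BIO tags,
# matching the B_/I_ format used by tag_perform() below.
sample = [(u'温家宝', u'PER'), (u'在', u'O'), (u'北京', u'LOC'), (u'会见', u'O'), (u'来宾', u'O')]
chars, tags = [], []
for word, label in sample:
    for i, ch in enumerate(word):
        chars.append(ch)
        tags.append(u'O' if label == u'O' else (u'B_' if i == 0 else u'I_') + label)
print(list(zip(chars, tags)))
# [('温', 'B_PER'), ('家', 'I_PER'), ('宝', 'I_PER'), ('在', 'O'), ('北', 'B_LOC'), ('京', 'I_LOC'), ...]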

import re
import sklearn_crfsuite  # pip install sklearn-crfsuite
from sklearn_crfsuite import metrics
from sklearn.externals import joblib  # on newer scikit-learn versions use "import joblib" instead

"""Initialization"""
train_corpus_path = "D:\\workspace\\project\\NLPcase\\ner\\data\\199801.txt"
process_corpus_path = "D:\\workspace\\project\\NLPcase\\ner\\data\\result-rmrb.txt"
# POS tags of the People's Daily corpus mapped to entity labels:
# t = time, nr = person name, ns = place name, nt = organization
_maps = {u't': u'T', u'nr': u'PER', u'ns': u'LOC', u'nt': u'ORG'}
def read_corpus_from_file(file_path):
    """Read the corpus"""
    f = open(file_path, 'r')  # depending on the local copy, an explicit encoding (e.g. 'utf-8' or 'gbk') may be needed
    lines = f.readlines()
    f.close()
    return lines

def write_corpus_to_file(data, file_path):
    """Write the corpus"""
    f = open(file_path, 'wb')
    f.write(data)
    f.close()
def q_to_b(q_str):
    """Convert full-width characters to half-width"""
    b_str = ""
    for uchar in q_str:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space maps directly to half-width space
            inside_code = 32
        elif 65374 >= inside_code >= 65281:  # other full-width characters map by a fixed offset
            inside_code -= 65248
        b_str += chr(inside_code)
    return b_str

def b_to_q(b_str):
    """Convert half-width characters to full-width"""
    q_str = ""
    for uchar in b_str:
        inside_code = ord(uchar)
        if inside_code == 32:  # half-width space maps directly to full-width space
            inside_code = 12288
        elif 126 >= inside_code >= 32:  # other half-width characters map by a fixed offset
            inside_code += 65248
        q_str += chr(inside_code)
    return q_str
def pre_process():
    """Corpus preprocessing"""
    lines = read_corpus_from_file(train_corpus_path)
    new_lines = []
    flag = 0
    for line in lines:
        flag += 1
        words = q_to_b(line.strip()).split(u' ')
        pro_words = process_t(words)
        pro_words = process_nr(pro_words)
        pro_words = process_k(pro_words)
        new_lines.append(' '.join(pro_words[1:]))  # drop the leading article-id token
        if flag == 100:  # only the first 100 lines are kept for this quick demo; remove the limit for real training
            break
    write_corpus_to_file(data='\n'.join(new_lines).encode('utf-8'), file_path=process_corpus_path)
def process_k(words):
    """Merge coarse-grained tokens bracketed in the corpus, e.g. [巴萨/n 俱乐部/n]nt"""
    pro_words = []
    index = 0
    temp = u''
    while True:
        word = words[index] if index < len(words) else u''
        if u'[' in word:
            temp += re.sub(pattern=u'/[a-zA-Z]*', repl=u'', string=word.replace(u'[', u''))
        elif u']' in word:
            w = word.split(u']')
            temp += re.sub(pattern=u'/[a-zA-Z]*', repl=u'', string=w[0])
            pro_words.append(temp + u'/' + w[1])
            temp = u''
        elif temp:
            temp += re.sub(pattern=u'/[a-zA-Z]*', repl=u'', string=word)
        elif word:
            pro_words.append(word)
        else:
            break
        index += 1
    return pro_words
def process_nr(words):
    """Merge surname and given name that the corpus tags separately, e.g. 温/nr 家宝/nr"""
    pro_words = []
    index = 0
    while True:
        word = words[index] if index < len(words) else u''
        if u'/nr' in word:
            next_index = index + 1
            if next_index < len(words) and u'/nr' in words[next_index]:
                pro_words.append(word.replace(u'/nr', u'') + words[next_index])
                index = next_index
            else:
                pro_words.append(word)
        elif word:
            pro_words.append(word)
        else:
            break
        index += 1
    return pro_words
def process_t(words):
    """Merge time words that the corpus tags separately, e.g. (/w 一九九七年/t 十二月/t 三十一日/t )/w"""
    pro_words = []
    index = 0
    temp = u''
    while True:
        word = words[index] if index < len(words) else u''
        if u'/t' in word:
            temp = temp.replace(u'/t', u'') + word
        elif temp:
            pro_words.append(temp)
            pro_words.append(word)
            temp = u''
        elif word:
            pro_words.append(word)
        else:
            break
        index += 1
    return pro_words
def pos_to_tag(p):
    """Map a POS tag to an entity label"""
    t = _maps.get(p, None)
    return t if t else u'O'

def tag_perform(tag, index):
    """Render labels in the BIO scheme"""
    if index == 0 and tag != u'O':
        return u'B_{}'.format(tag)
    elif tag != u'O':
        return u'I_{}'.format(tag)
    else:
        return tag

def pos_perform(pos):
    """Remove the label prior knowledge carried by entity-bearing POS tags"""
    if pos in _maps.keys() and pos != u't':
        return u'n'
    else:
        return pos
def initialize():
    """Initialization"""
    lines = read_corpus_from_file(process_corpus_path)
    words_list = [line.strip().split(' ') for line in lines if line.strip()]
    del lines
    return init_sequence(words_list)

def init_sequence(words_list):
    """Initialize the character, POS and label sequences"""
    words_seq = [[word.split(u'/')[0] for word in words] for words in words_list]
    pos_seq = [[word.split(u'/')[1] for word in words] for words in words_list]
    tag_seq = [[pos_to_tag(p) for p in pos] for pos in pos_seq]
    # expand the word-level POS tags and labels onto each character of the word
    pos_seq = [[[pos_seq[index][i] for _ in range(len(words_seq[index][i]))]
                for i in range(len(pos_seq[index]))] for index in range(len(pos_seq))]
    tag_seq = [[[tag_perform(tag_seq[index][i], w) for w in range(len(words_seq[index][i]))]
                for i in range(len(tag_seq[index]))] for index in range(len(tag_seq))]
    # flatten each sentence into a character-level sequence, with boundary padding
    pos_seq = [[u'un'] + [pos_perform(p) for pos in seq for p in pos] + [u'un'] for seq in pos_seq]
    tag_seq = [[t for tag in seq for t in tag] for seq in tag_seq]
    word_seq = [[u'<BOS>'] + [c for word in seq for c in word] + [u'<EOS>'] for seq in words_seq]
    return pos_seq, tag_seq, word_seq
pre_process()
pos_seq, tag_seq, word_seq = initialize()
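# For illustration only (the values are examples, not printed by the code): after initialize(),
# each element of word_seq is a character sequence such as ['<BOS>', '迈', '向', ..., '<EOS>'],
# tag_seq holds the matching BIO labels such as ['O', 'O', ..., 'B_PER', 'I_PER', ...],
# and pos_seq holds the character-level POS tags padded with 'un' at both ends.
# word_seq carries two extra boundary symbols, so the 3-character windows built below
# line up one-to-one with the labels in tag_seq.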
def extract_feature(word_grams):
    """Feature extraction"""
    features, feature_list = [], []
    for index in range(len(word_grams)):
        for i in range(len(word_grams[index])):
            word_gram = word_grams[index][i]
            feature = {u'w-1': word_gram[0], u'w': word_gram[1], u'w+1': word_gram[2],
                       u'w-1:w': word_gram[0] + word_gram[1], u'w:w+1': word_gram[1] + word_gram[2],
                       # POS-based features (p-1, p, p+1, ...) built from pos_seq could be added here as well
                       u'bias': 1.0}
            feature_list.append(feature)
        features.append(feature_list)
        feature_list = []
    return features
def segment_by_window(words_list=None, window=3):
    """Slice a sequence into sliding windows"""
    words = []
    begin, end = 0, window
    for _ in range(1, len(words_list)):
        if end > len(words_list):
            break
        words.append(words_list[begin:end])
        begin = begin + 1
        end = end + 1
    return words
def generator():
    """Build the training data"""
    word_grams = [segment_by_window(word_list) for word_list in word_seq]
    features = extract_feature(word_grams)
    return features, tag_seq
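# For illustration only (a hypothetical window over the characters '温', '家', '宝'):
# each training instance produced above is a feature dict such as
#   {'w-1': '温', 'w': '家', 'w+1': '宝', 'w-1:w': '温家', 'w:w+1': '家宝', 'bias': 1.0}
# paired, at the same position, with the character's BIO label from tag_seq, e.g. 'I_PER'.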
# ------------------ The key step: pair each character's feature dict with its observed tag ------------------
'''Initialize the model parameters'''
algorithm = 'lbfgs'
c1 = 0.1
c2 = 0.1
max_iterations = 100
model_path = "D:\\workspace\\project\\NLPcase\\ner\\model\\model.pkl"
model = sklearn_crfsuite.CRF(algorithm=algorithm, c1=c1, c2=c2,
                             max_iterations=max_iterations, all_possible_transitions=True)
def save_model(model, model_path):
    """Save the model"""
    joblib.dump(model, model_path)

def load_model(model_path):
    """Load the model"""
    return joblib.load(model_path)
# Train the model
def train(model_path):
    x, y = generator()
    # the first 500 sentences are held out for testing, the rest are used for training
    # (this requires more than 500 preprocessed sentences, i.e. the 100-line limit in pre_process must be lifted)
    x_train, y_train = x[500:], y[500:]
    x_test, y_test = x[:500], y[:500]
    model.fit(x_train, y_train)
    labels = list(model.classes_)
    labels.remove('O')
    y_predict = model.predict(x_test)
    print(metrics.flat_f1_score(y_test, y_predict, average='weighted', labels=labels))
    sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))
    print(metrics.flat_classification_report(y_test, y_predict, labels=sorted_labels, digits=3))
    save_model(model, model_path)
def predict(sent):
    model = load_model(model_path)
    u_sent = q_to_b(sent)
    word_lists = [[u'<BOS>'] + [c for c in u_sent] + [u'<EOS>']]
    word_grams = [segment_by_window(word_list) for word_list in word_lists]
    features = extract_feature(word_grams)
    y_predict = model.predict(features)
    entity = u''
    for index in range(len(y_predict[0])):
        if y_predict[0][index] != u'O':
            # start a new entity when the predicted label type changes
            if index > 0 and y_predict[0][index][-1] != y_predict[0][index - 1][-1]:
                entity += u' '
            entity += u_sent[index]
        elif entity and entity[-1] != u' ':
            entity += u' '
    return entity
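  To close the loop, here is a minimal usage sketch. It assumes the paths defined above point at a local copy of the 199801 corpus and a writable model directory, that the 100-line limit in pre_process has been removed so that more than 500 sentences are available for the train/test split, and that the input sentence is only an example.

if __name__ == '__main__':
    train(model_path)  # fit the CRF and save it to model_path
    # hypothetical input sentence; the recognised entities come back as a space-separated string
    print(predict(u'新华社北京十二月三十一日电'))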

References

https://blog.csdn.net/leitouguan8655/article/details/83382412

https://blog.csdn.net/lhxsir/article/details/83387240
