
Training and Updating spaCy Models (nlp.update, spaCy 3)

Model Training

  • Initialize the model weights to random values: call the nlp.begin_training method;

  • Check how well the current weights predict: call the nlp.update method;

  • Compare the predictions against the true labels;

  • Calculate how to adjust the weights to improve the predictions;

  • Fine-tune the model weights;

  • Add entity-recognition rules to improve recognition accuracy (a minimal EntityRuler sketch follows this list);

  • Repeat the steps above.
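
The "add entity-recognition rules" step can be handled with spaCy's EntityRuler, which the full example further below also uses. A minimal sketch; the labels and patterns here are illustrative placeholders, not the article's real data:

import spacy

nlp = spacy.blank("zh")
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([
    {"label": "BRAND", "pattern": "华为"},
    {"label": "TYPE", "pattern": "p30"},
])

doc = nlp("华为p30开不了机")
print([(ent.text, ent.label_) for ent in doc.ents])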

Training loop:

for i in range(10):
    random.shuffle(TRAINING_DATA)
    for batch in spacy.util.minibatch(TRAINING_DATA):
        texts = [text for text, annotation in batch]
        annotations = [annotation for text, annotation in batch]
        nlp.update(texts, annotations)
nlp.to_disk(path_to_model)
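
Note that nlp.update(texts, annotations) is the spaCy 2.x calling convention. In spaCy 3, nlp.update expects Example objects, so the equivalent loop would look roughly like this (a sketch: nlp, TRAINING_DATA, the batch size and path_to_model are assumed to exist, as in the snippet above):

import random
from spacy.training import Example
from spacy.util import minibatch

# assumes nlp and TRAINING_DATA are already set up
optimizer = nlp.begin_training()
for i in range(10):
    random.shuffle(TRAINING_DATA)
    for batch in minibatch(TRAINING_DATA, size=8):
        # spaCy 3: wrap each (text, annotations) pair in an Example before calling update
        examples = [Example.from_dict(nlp.make_doc(text), annotations)
                    for text, annotations in batch]
        nlp.update(examples, sgd=optimizer)
nlp.to_disk(path_to_model)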

Example:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Import the required libraries
import json
import os
import random
import shutil
import time
import uuid

import spacy
from spacy.lang.zh import Chinese
from spacy.training import Example
from spacy.util import minibatch

# Paths of the temporary training-data files
file_data = []

# Recreate the temporary directory
temp_data = "./tempData"
if os.path.exists(temp_data):
    shutil.rmtree(temp_data)
os.makedirs(temp_data)


# Write the current batch of training data to a temporary file and remember its path
def saveFile():
    path = temp_data + "/" + str(uuid.uuid4()).replace("-", "") + ".json"
    file_data.append(path)
    return open(path, 'a+', encoding="utf-8")


# Load brand data
with open("./pingpai_pref.json", 'r', encoding='utf-8') as load_pin:
    load_pin = json.load(load_pin)
    # load_pin = load_pin[0:1000]

# Load model (product type) data
with open("./xinghao_pref.json", 'r', encoding="utf-8") as load_xin:
    load_xin = json.load(load_xin)
    load_xin = load_xin[0:1000]

# Load category data
with open("./fenlei_pref.json", 'r', encoding="utf-8") as load_fen:
    load_fen = json.load(load_fen)
    # load_fen = load_fen[0:1000]

# Load fault-description data
with open("./desc.json", 'r', encoding="utf-8") as load_xinDesc:
    load_xinDesc = json.load(load_xinDesc)

TRAIN_DATA = []

brand = "BRAND"
type = "TYPE"
classify = "CLASSIFY"
DESC = "DESC"

# Leading noise words (prepended before an entity)
head = ["我的", "今天", "突然", "偶尔", "好像", "明天", "有时候", "貌似", "不知道", "我有一个", "多个", "怎么 "]
hl = len(head) - 1

# Trailing noise words (appended after an entity)
end = ["然后", "突然就", "后", "就", "导致", "引发", "影响", "后面", "触发", "和", ","]
el = len(end) - 1

# Brands
brandLen = len(load_pin) - 1
print("brands:", len(load_pin))

# Categories
fenLen = len(load_fen) - 1
print("categories:", len(load_fen))

# Models
xinLen = len(load_xin) - 1
print("models:", len(load_xin))

# Descriptions
desc_array_len = len(load_xinDesc) - 1
print("descriptions:", len(load_xinDesc))

print("Start building brand/model/category data", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

for i in load_pin:
    pingPai = i.strip()
    # single brand
    val = (pingPai, {'entities': [(0, len(pingPai), brand)]})
    TRAIN_DATA.append(val)
    # two brands concatenated
    t = load_pin[random.randint(0, brandLen)].strip()
    val = (pingPai + t, {'entities': [(0, len(pingPai), brand),
                                      (len(pingPai), len(pingPai + t), brand)]})
    TRAIN_DATA.append(val)
    # brand + category
    cate = load_fen[random.randint(0, fenLen)].strip()
    val = (
        pingPai + cate,
        {'entities': [(0, len(pingPai), brand),
                      (len(pingPai), len(pingPai) + len(cate), classify)]})
    TRAIN_DATA.append(val)
    # brand + model
    xin = load_xin[random.randint(0, xinLen)].strip()
    val = (
        pingPai + xin,
        {'entities': [(0, len(pingPai), brand),
                      (len(pingPai), len(pingPai) + len(xin), type)]})
    TRAIN_DATA.append(val)
    # brand with a leading noise word
    t = head[random.randint(0, hl)]
    val = (
        t + pingPai,
        {'entities': [(len(t), len(t) + len(pingPai), brand)]})
    TRAIN_DATA.append(val)
    # brand with a trailing noise word
    t = end[random.randint(0, el)]
    val = (
        pingPai + t,
        {'entities': [(0, len(pingPai), brand)]})
    TRAIN_DATA.append(val)

# Models
print("models:", len(load_xin))
for it in load_xin:
    xinHao = it.strip()
    # single model
    val = (xinHao, {'entities': [(0, len(xinHao), type)]})
    TRAIN_DATA.append(val)
    # model + category
    cate = load_fen[random.randint(0, fenLen)].strip()
    val = (
        xinHao + cate, {'entities': [(0, len(xinHao), type), (len(xinHao), len(xinHao) + len(cate), classify)]})
    TRAIN_DATA.append(val)
    # model + description
    desc_t = load_xinDesc[random.randint(0, desc_array_len)]
    val = (
        xinHao + desc_t, {'entities': [(0, len(xinHao), type), (len(xinHao), len(xinHao) + len(desc_t), DESC)]})
    TRAIN_DATA.append(val)
    # model with a leading noise word
    t = head[random.randint(0, hl)]
    val = (
        t + xinHao,
        {'entities': [(len(t), len(t) + len(xinHao), type)]})
    TRAIN_DATA.append(val)
    # model with a trailing noise word
    t = end[random.randint(0, el)]
    val = (
        xinHao + t,
        {'entities': [(0, len(xinHao), type)]})
    TRAIN_DATA.append(val)

# Descriptions
desc_array = []
for s in load_xinDesc:
    desc = s.strip()
    # single description
    val = (desc, {'entities': [(0, len(desc), DESC)]})
    TRAIN_DATA.append(val)
    desc_array.append(desc)
    # two descriptions concatenated
    t = load_xinDesc[random.randint(0, desc_array_len)].strip()
    val = (desc + t, {'entities': [(0, len(desc), DESC),
                                   (len(desc), len(desc + t), DESC)]})
    TRAIN_DATA.append(val)

# Categories
for it in load_fen:
    classify_t = it.strip()
    # single category
    val = (classify_t, {'entities': [(0, len(classify_t), classify)]})
    TRAIN_DATA.append(val)
    # two categories concatenated
    t = load_fen[random.randint(0, fenLen)].strip()
    val = (classify_t + t, {'entities': [(0, len(classify_t), classify),
                                         (len(classify_t), len(classify_t + t), classify)]})
    TRAIN_DATA.append(val)
    # category with a leading noise word
    t = head[random.randint(0, hl)]
    val = (
        t + classify_t,
        {'entities': [(len(t), len(t) + len(classify_t), classify)]})
    TRAIN_DATA.append(val)
    # category with a trailing noise word
    t = end[random.randint(0, el)]
    val = (
        classify_t + t,
        {'entities': [(0, len(classify_t), classify)]})
    TRAIN_DATA.append(val)
    # category + two descriptions
    t = load_xinDesc[random.randint(0, desc_array_len)]
    t2 = load_xinDesc[random.randint(0, desc_array_len)]
    val = (
        classify_t + t + t2,
        {'entities': [(0, len(classify_t), classify),
                      (len(classify_t), len(classify_t + t), DESC),
                      (len(classify_t + t), len(classify_t + t + t2), DESC)]})
    TRAIN_DATA.append(val)

print("品牌 型号 分类 data size: ", str(len(TRAIN_DATA)))

file = saveFile()
file.write(json.dumps(TRAIN_DATA, indent=4, ensure_ascii=False))
file.close()

# 清空防止内存溢出
TRAIN_DATA = []

print("品牌 - 型号 - 分类 - 描述 正在处理...")
count_three = 0
# 品牌 - 型号 - 分类 - 描述
for i in load_pin:
    rd = random.randint(0, 40)
    if rd >= 2:
        continue
    # brand
    pingPai = i.strip()
    pingPai_y = {'entities': [(0, len(pingPai), brand)]}
    for j in load_xin:
        rd = random.randint(0, 30)
        if rd >= 2:
            continue
        xinHao = j.strip()
        result_concat = pingPai
        result_concat += xinHao
        # categories
        for k in load_fen:
            classify_p = k.strip()

            # three-in-one (brand + model + category) variant, left commented out
            # pingPai_copy = json.loads(json.dumps(pingPai_y))
            # pingPai_copy.get("entities").append((len(pingPai), len(result_concat) - 1, type))
            # pingPai_copy.get("entities").append(
            #     (len(result_concat), len(result_concat) + len(classify_p) - 1, classify))
            #
            # result_concat_copy = str(result_concat)
            # result_concat_copy += classify_p
            # res_final = (result_concat_copy, pingPai_copy)
            # TRAIN_DATA.append(res_final)

            # pick a leading noise word, a trailing noise word and a random description
            rd_hl = random.randint(0, hl)
            hh = head[rd_hl]
            rd_el = random.randint(0, el)
            ee = end[rd_el]
            des_r = random.randint(0, desc_array_len)
            desc_v = desc_array[des_r]

            pingPai_yy = {'entities': [(len(hh), len(hh) + len(pingPai), brand)]}

            pingPai_copy2 = json.loads(json.dumps(pingPai_yy))
            pingPai_copy2.get("entities").append((len(pingPai) + len(hh), len(result_concat) + len(hh), type))
            pingPai_copy2.get("entities").append(
                (len(result_concat) + len(hh), len(result_concat) + len(hh) + len(classify_p), classify))

            result_concat_copy2 = str(result_concat)
            result_concat_copy2 += classify_p
            result_concat_copy2 = hh + result_concat_copy2

            # single fault description
            pingPai_copy2.get("entities").append(
                (len(result_concat_copy2), len(result_concat_copy2 + desc_v), DESC))
            result_concat_copy2 = result_concat_copy2 + desc_v

            # second fault description, separated by a trailing noise word
            des_t = random.randint(0, desc_array_len)
            desc_t = desc_array[des_t]
            result_concat_copy2 = result_concat_copy2 + ee
            pingPai_copy2.get("entities").append(
                (len(result_concat_copy2), len(result_concat_copy2 + desc_t), DESC))
            result_concat_copy2 = result_concat_copy2 + desc_t

            res_final = (result_concat_copy2, pingPai_copy2)
            # print("res_final", res_final)
            TRAIN_DATA.append(res_final)

            count_three = count_three + 1

            if len(TRAIN_DATA) % 10000 == 0:
                file = saveFile()
                file.write(json.dumps(TRAIN_DATA, indent=4, ensure_ascii=False))
                file.close()
                TRAIN_DATA = []


print("品牌 型号 分类 处理结束, data size:", str(count_three))

file = saveFile()
file.write(json.dumps(TRAIN_DATA, indent=4, ensure_ascii=False))
file.close()

# Clear the list to avoid running out of memory
TRAIN_DATA = []

output_dir = "../../model"
ruler_dir = "./ruler_model"

nlp = Chinese()
ruler = nlp.add_pipe("entity_ruler", config={"validate": True})

patterns = []
# Rule patterns for brands, models, descriptions and categories
for i in load_pin:
    patterns.append({"label": brand, "pattern": i, "id": brand})
for i in load_xin:
    patterns.append({"label": type, "pattern": i, "id": type})
for i in load_xinDesc:
    patterns.append({"label": DESC, "pattern": i, "id": DESC})
for i in load_fen:
    patterns.append({"label": classify, "pattern": i, "id": classify})

ruler.add_patterns(patterns)
nlp.to_disk(ruler_dir)

# Named entity recognition: in spaCy 3, add_pipe creates and returns the component,
# so the labels added below end up on the pipeline's ner component
ner = nlp.add_pipe("ner")
# Number of training iterations
n_iter = 1

# Entity labels (alternatively, collect them from TRAIN_DATA with the commented-out loop below)
# for _, annotations in TRAIN_DATA:
#     for ent in annotations.get('entities'):
#         ner.add_label(ent[2])

ner.add_label(brand)
ner.add_label(type)
ner.add_label(classify)
ner.add_label(DESC)

pipe_exceptions = ["tok2vec", "tagger", "parser", "ner", "entity_ruler"]

# Train only the components we need; anything not disabled would also be trained.
# Avoid retraining a downloaded spaCy model, which can corrupt its language components;
# training a blank language model (as here) is the safer choice.
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
with nlp.disable_pipes(*other_pipes):
    # 模型初始化
    optimizer = nlp.begin_training()
    # 训练次数 次数越多越精确
    print("训练次数", str(n_iter * len(file_data)))
    for itn in range(n_iter):

        for f in file_data:
            with open(f, 'r', encoding="utf-8") as fItem:
                TRAIN_DATA = json.load(fItem)

            # 训练数据每次迭代打乱顺序
            random.shuffle(TRAIN_DATA)
            # 定义损失函数
            losses = {}
            # 分批训练
            example = []
            for text, annotations in TRAIN_DATA:
                # 对数据进行整理成新模型需要的数据
                example.append(Example.from_dict(nlp.make_doc(text), annotations))
            # 训练
            for batch in minibatch(example, size=500):
                # drop。使模型更难记住数据 , 使模型更难记住数据。 例如,辍学意味着每个功能或内部表示具有 1/4 的被丢弃的可能性
                # update 参考地址 https://spacy.io/api/language#update
                # 训练参考文档 https://spacy.io/usage/training#custom-functions
                # https: // spacy.io / usage / rule - based - matching  # entityruler
                nlp.update(batch, drop=0.1, sgd=optimizer, losses=losses)
            print(str(itn), losses)

# Save the model
nlp.to_disk(output_dir)
print("Saved model to", output_dir, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

# Merge the rule patterns into the saved model
shutil.copyfile(ruler_dir + "/entity_ruler/patterns.jsonl", output_dir + "/entity_ruler/patterns.jsonl")

print("-----------test----------")
text = "笔记本电脑黑屏"
nlp = spacy.load(output_dir)
doc = nlp(text)
print('ner', [(t.text, t.label_) for t in doc.ents])

text = "小米笔记本电脑黑屏"
nlp = spacy.load(output_dir)
doc = nlp(text)

print('ner', [(t.text, t.label_) for t in doc.ents])

# Clean up the temporary files
# if os.path.exists(temp_data):
#     shutil.rmtree(temp_data)






Updating a Trained Model

  • Load the existing model weights and resume training from them (rather than re-initializing them randomly): call the nlp.resume_training() method;

  • Check how well the current weights predict: call the nlp.update method;

  • Compare the predictions against the true labels;

  • Calculate how to adjust the weights to improve the predictions;

  • Fine-tune the model weights;

  • Repeat the steps above.

Example:

# Import the required libraries
# Reference: https://www.cnblogs.com/Ukiii/p/14709696.html
import random

import spacy
from spacy.training import Example

TRAIN_DATA = [
    ('谁是Shaka Khan?', {
        'entities': [(2, 12, 'PERSON')]
        # entity offsets are 0-based character indices; the end index is the
        # index of the last character + 1 (here "Shaka Khan" spans characters 2-11)
    }),
    ('I like London and Berlin.', {
        'entities': [(7, 13, 'LOC'), (18, 24, 'LOC')]
    }),
    ('我的华为p30开不了机怎么办', {
        'entities': [(4, 7, 'TYPE')]
    }),
    ('p30', {
        'entities': [(0, 3, 'TYPE')]
    }),
    ('华但是一个组织', {
        'entities': [(0, 1, 'ORG')]
    })
]

output_dir = "./model"

nlp = spacy.load('zh_core_web_lg')
ner = nlp.get_pipe("ner")
n_iter = 100

# add labels
for _, annotations in TRAIN_DATA:
    for ent in annotations.get('entities'):
        ner.add_label(ent[2])

# Train only the labels we annotated; if nothing is disabled, all components are trained.
# Be careful when retraining a downloaded spaCy model: it can degrade the pretrained
# components, so training a blank language model is often the safer option.
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
with nlp.disable_pipes(*other_pipes):

    # Resume training from the existing model weights
    optimizer = nlp.resume_training()
    for itn in range(n_iter):

        # Shuffle the training data on every iteration
        random.shuffle(TRAIN_DATA)
        # Loss dictionary
        losses = {}
        for text, annotations in TRAIN_DATA:
            # Convert each pair into the Example objects the new API expects
            example = Example.from_dict(nlp.make_doc(text), annotations)
            print("example:", example)
            nlp.update([example], drop=0.5, sgd=optimizer, losses=losses)
        print(losses)

# Save the model
nlp.to_disk(output_dir)
print("Saved model to", output_dir)

print("-----------test----------")

text="P30开不了机"
nlp = spacy.load(output_dir)
print("Loading from", output_dir)
doc = nlp(text)
for i in doc.ents:
   print(i.text,i.label_)

Training Data Format

TRAINING_DATA = [
    ("...", {"entities": [(0, 1, "WEBSITE")]}),
    ("...", {"entities": [(0, 1, "PERSON")]})
]
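
When writing these character offsets by hand it is easy to get them wrong: they are 0-based, the end index is exclusive, and the span must line up with token boundaries. A quick check with Doc.char_span can catch misaligned annotations before training; the text and label below are illustrative placeholders:

import spacy

nlp = spacy.blank("zh")

TRAINING_DATA = [
    ("我的华为p30开不了机", {"entities": [(4, 7, "TYPE")]}),
]

for text, annotations in TRAINING_DATA:
    doc = nlp.make_doc(text)
    for start, end, label in annotations["entities"]:
        span = doc.char_span(start, end, label=label)
        if span is None:
            print("misaligned span", (start, end, label), "in", text)
        else:
            print("ok:", span.text, "->", label)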
