BERT-PyTorch: Multi-Class Classification of English Text

Most BERT text-classification tutorials online cover Chinese text; it took me a long time to find a reference for English. As a beginner, I know that when borrowing someone else's training code, the key points are the input/output format and the parameters you need to change, so here is how I got a borrowed codebase running on my own dataset.

The author I followed: https://www.bilibili.com/video/BV1DQ4y1U7jG/?spm_id_from=333.337.search-card.all.click&vd_source=fc681f6795f19749927a346ce9af92e8

1. Prepare the BERT model

From bert-base-uncased at main (huggingface.co), download the model files (the original post showed them in a screenshot; at minimum you need the pretrained weights pytorch_model.bin and the vocabulary vocab.txt, since the scripts below write config.json into the folder themselves) and put them in one folder named bert-base-uncased (the name is arbitrary; you just point the path at it later).
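If you prefer to script the download, something like the following works (my own addition, assuming a reasonably recent huggingface_hub package is installed; downloading the files by hand from the site works just as well):

import huggingface_hub

# Fetch only the files the scripts below need; config.json is written locally anyway.
huggingface_hub.snapshot_download(
    repo_id='bert-base-uncased',
    local_dir='./bert-base-uncased',
    allow_patterns=['pytorch_model.bin', 'vocab.txt'],
)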

2. Prepare the dataset

I used the AG News dataset, which has four labels, i.e. a four-class problem. The data format is as follows:

It is the most common two-column layout: the first column is the text, the second is the label. When preparing the data, first strip every '\t' character out of the text itself, and only then append a '\t' plus the label at the end of each line; otherwise the loader's line.split('\t') can no longer split each line cleanly into text and label (I'm a beginner, and this is exactly the mistake I made at first):

text\tlabel

Split the data into train.txt and valid.txt, both in the format above.
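For example, here is a minimal conversion sketch (the input file name and column layout are assumptions based on the common AG News CSV distribution of (label, title, description) rows; adapt them to your copy):

import csv

with open('train.csv', encoding='utf-8') as src, \
        open('./data/train.txt', 'w', encoding='utf-8') as dst:
    for label, title, description in csv.reader(src):
        # Strip stray tabs/newlines so each output line has exactly one tab.
        text = (title + ' ' + description).replace('\t', ' ').replace('\n', ' ')
        label = int(label) - 1            # AG News labels run 1-4; the code below expects 0-3
        dst.write(f'{text}\t{label}\n')   # one tab, right before the label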

3. Train the model

With everything above in place, you can start training. Below is the training code I adapted from the author linked above.

from pytorch_pretrained_bert import BertModel, BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import re
import time
import pickle as pkl
from datetime import timedelta
from tqdm import tqdm
from sklearn import metrics

# Write a config.json next to the downloaded weights so pytorch_pretrained_bert
# can load the model. (The original used `as F` here, which shadowed
# torch.nn.functional and broke F.cross_entropy later on.)
with open('./bert-base-uncased/config.json', 'w') as cfg_file:
    cfg_file.write('''
{
    "architectures": ["BertForMaskedLM"],
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "type_vocab_size": 2,
    "vocab_size": 30522
}
''')


class Config(object):
    """Configuration parameters."""
    def __init__(self, dataset):
        self.model_name = 'bertrnn'                  # model name
        self.train_path = './data/train.txt'         # training set
        self.dev_path = './data/valid.txt'           # validation set
        self.test_path = './data/test.txt'           # test set
        self.class_list = [0, 1, 2, 3]               # class labels; must match the labels in the data
        self.save_path = dataset + self.model_name + '.ckpt'  # where the best model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.require_improvement = 1000  # stop early after 1000 batches without improvement
        self.num_classes = len(self.class_list)
        self.num_epochs = 5
        self.batch_size = 64  # lower this if GPU memory is tight; even 1 or 2 works, just slowly
        self.pad_size = 73    # sequence length after padding/truncation
        self.learning_rate = 1e-5
        self.bert_path = './bert-base-uncased'       # directory with the pretrained model
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768                       # BERT hidden size
        self.dropout = 0.1
        self.datasetpkl = dataset + 'datasetqq.pkl'  # cache for the tokenized dataset


PAD, CLS = '[PAD]', '[CLS]'


def load_dataset(file_path, config):
    '''
    Read tab-separated (text, label) lines and return a list of
    (token_ids, label, seq_len, mask) tuples.
    '''
    contents = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            content, label = line.split('\t')
            # Light cleaning: drop URLs, expand contractions, normalise numbers.
            content = re.sub(r"https?://\S+", "", content)
            content = re.sub(r"what's", "what is", content)
            content = re.sub(r"Won't", "will not", content)
            content = re.sub(r"can't", "can not", content)
            content = re.sub(r"\'s", " ", content)
            content = re.sub(r"\'ve", " have", content)
            content = re.sub(r"n't", " not", content)
            content = re.sub(r"i'm", "i am", content)
            content = re.sub(r"\'re", " are", content)
            content = re.sub(r"\'d", " would", content)
            content = re.sub(r"\'ll", " will", content)
            content = re.sub(r"e - mail", "email", content)
            content = re.sub(r"\d+ ", "NUM ", content)  # the original replacement "NUM" swallowed the trailing space
            content = re.sub(r"<br />", '', content)
            # Keep only spaces, letters and a few symbols. (The original range
            # started at \u007a, which also deleted the letter 'z'.)
            content = re.sub(r'[\u0000-\u0019\u0021-\u0040\u007b-\uffff]', '', content)
            token = config.tokenizer.tokenize(content)                 # tokenize
            token = [CLS] + token
            seq_len = len(token)
            mask = []
            token_ids = config.tokenizer.convert_tokens_to_ids(token)  # tokens -> ids
            pad_size = config.pad_size
            if pad_size:
                if len(token) < pad_size:
                    mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                    token_ids = token_ids + [0] * (pad_size - len(token))
                else:
                    mask = [1] * pad_size
                    token_ids = token_ids[:pad_size]
                    seq_len = pad_size
            contents.append((token_ids, int(label), seq_len, mask))
    return contents


def build_dataset(config):
    '''Return (train, dev), reading from the pickle cache when it exists.'''
    if os.path.exists(config.datasetpkl):
        dataset = pkl.load(open(config.datasetpkl, 'rb'))
        train = dataset['train']
        dev = dataset['dev']
    else:
        train = load_dataset(config.train_path, config)
        dev = load_dataset(config.dev_path, config)
        dataset = {'train': train, 'dev': dev}
        pkl.dump(dataset, open(config.datasetpkl, 'wb'))
    return train, dev


class DatasetIterator:
    def __init__(self, dataset, batch_size, device):
        self.batch_size = batch_size
        self.dataset = dataset
        self.n_batch = len(dataset) // batch_size  # number of full batches
        self.device = device
        # True if the last batch is smaller than the rest. (The original tested
        # `len(dataset) % self.n_batch`, which divides by zero when the dataset
        # holds fewer samples than batch_size; see the note after the prediction script.)
        self.residue = len(dataset) % batch_size != 0
        self.index = 0  # start from the first batch

    def _to_tensor(self, datas):
        x = torch.LongTensor([item[0] for item in datas]).to(self.device)
        y = torch.LongTensor([item[1] for item in datas]).to(self.device)
        seq_len = torch.LongTensor([item[2] for item in datas]).to(self.device)
        mask = torch.LongTensor([item[3] for item in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batch:
            # the final, smaller batch
            batches = self.dataset[self.index * self.batch_size:len(self.dataset)]
            self.index += 1
            return self._to_tensor(batches)
        elif self.index > self.n_batch:
            self.index = 0
            raise StopIteration
        elif self.index == self.n_batch and not self.residue:
            self.index = 0
            raise StopIteration
        else:
            batches = self.dataset[self.index * self.batch_size:(self.index + 1) * self.batch_size]
            self.index += 1
            return self._to_tensor(batches)

    def __iter__(self):
        return self

    def __len__(self):
        return self.n_batch + 1 if self.residue else self.n_batch


def build_iterator(dataset, config):
    return DatasetIterator(dataset, config.batch_size, config.device)


def get_time_dif(start_time):
    '''Elapsed time since start_time.'''
    end_time = time.time()
    return timedelta(seconds=int(end_time - start_time))


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True  # fine-tune all BERT parameters
        self.fc = nn.Linear(config.hidden_size, config.num_classes)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x):
        '''x is the (input_ids, seq_len, mask) tuple produced by the iterator.'''
        context = x[0]  # [batch_size, seq_len]
        mask = x[2]     # masks out the zero padding, [batch_size, seq_len]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        output = self.fc(pooled)
        output = self.dropout(output)  # dropout applied after the classifier, as in the original
        return output


def train(config, model, train_iter, dev_iter):
    start_time = time.time()
    model.train()  # enable dropout etc.
    param_optimizer = list(model.named_parameters())
    # bias and LayerNorm parameters get no weight decay
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.00}
    ]
    optimizer = BertAdam(
        params=optimizer_grouped_parameters,
        lr=config.learning_rate,
        warmup=0.05,
        # Total number of optimisation steps. (The original multiplied by
        # num_classes here, which is a bug; it should be num_epochs.)
        t_total=len(train_iter) * config.num_epochs
    )
    total_batch = 0    # batches processed so far
    dev_best_acc = 0   # best validation accuracy so far (the original also sketched saving on best loss)
    last_improve = 0   # batch at which validation last improved
    flag = False       # set when training should stop early
    for epoch in range(config.num_epochs):
        print('Epoch[{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 5 == 0:
                true = labels.data.cpu()
                predict = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predict)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_best_acc < dev_acc:  # save on a new best validation accuracy
                    dev_best_acc = dev_acc
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter:{0:>6},Train Loss{1:>5.2},Train Acc{2:>6.2},Val Loss{3:>5.2},Val Acc:{4:>6.2%},Time:{5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                print('No improvement on the validation set for a long time; stopping training early')
                flag = True
                break
        if flag:
            break


def evaluate(config, model, dev_iter, test=False):
    '''Validation; returns (acc, mean loss) and, with test=True, a report and confusion matrix.'''
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in dev_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss.item()
            labels = labels.data.cpu().numpy()
            predict = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predict)
    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        # target_names must be strings, hence str()
        report = metrics.classification_report(labels_all, predict_all,
                                               target_names=[str(c) for c in config.class_list], digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(dev_iter), report, confusion
    return acc, loss_total / len(dev_iter)


config = Config('dataset')
train_data, dev_data = build_dataset(config)
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
model = Model(config).to(config.device)
train(config, model, train_iter, dev_iter)

I trimmed some of the code myself: my test set is large, so the plan is to train the model first and then predict with a separate test.py. To run this on your own dataset, the changes are roughly:

1. The path where the config file from step 1 is written:

with open('./bert-base-uncased/config.json', 'w') as cfg_file:

2. The three dataset file paths:

# training set
self.train_path = './data/train.txt'
# validation set
self.dev_path = './data/valid.txt'
# test set
self.test_path = './data/test.txt'

3. Your own class list, matching the labels that appear in the data:

# class labels
self.class_list = [0, 1, 2, 3]

4. The path to the BERT files downloaded in step 1:

self.bert_path = './bert-base-uncased'

5. Beyond that, it is a matter of tuning batch_size, the learning rate, or pad_size yourself; one way to pick pad_size is sketched below.
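For pad_size in particular, a quick way to choose a value is to look at the token-length distribution of the training text. This helper is my own addition, not part of the original code; it assumes a Config instance as defined above:

import numpy as np

def suggest_pad_size(config, quantile=95):
    # Tokenize every training line and return the chosen percentile of lengths,
    # +1 for the [CLS] token the loader prepends.
    lengths = []
    with open(config.train_path, encoding='utf-8') as f:
        for line in f:
            text = line.rsplit('\t', 1)[0]
            lengths.append(len(config.tokenizer.tokenize(text)) + 1)
    return int(np.percentile(lengths, quantile))

print(suggest_pad_size(Config('dataset')))  # use the result as config.pad_size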

4. Predict with the model

Here is my prediction code (test.py); it repeats Config, load_dataset, DatasetIterator, build_iterator and Model so it runs standalone:

from pytorch_pretrained_bert import BertModel, BertTokenizer
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import re
from tqdm import tqdm

# Same config.json as in the training script, so test.py also runs standalone
# (again avoiding `as F`, which would shadow torch.nn.functional).
with open('./bert-base-uncased/config.json', 'w') as cfg_file:
    cfg_file.write('''
{
    "architectures": ["BertForMaskedLM"],
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "type_vocab_size": 2,
    "vocab_size": 30522
}
''')


class Config(object):
    """Configuration parameters (same as the training script, except batch_size)."""
    def __init__(self, dataset):
        self.model_name = 'bertrnn'
        self.train_path = './data/train.txt'
        self.dev_path = './data/valid.txt'
        self.test_path = './data/test.txt'
        self.class_list = [0, 1, 2, 3]
        self.save_path = dataset + self.model_name + '.ckpt'  # checkpoint written by the training script
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.require_improvement = 1000
        self.num_classes = len(self.class_list)
        self.num_epochs = 5
        self.batch_size = 2  # was 128; see the note on test-set size below
        self.pad_size = 73
        self.learning_rate = 1e-5
        self.bert_path = './bert-base-uncased'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
        self.dropout = 0.1
        self.datasetpkl = dataset + 'datasetqq.pkl'


PAD, CLS = '[PAD]', '[CLS]'


def load_dataset(file_path, config):
    '''Same loader as in the training script. For test.txt the second column
    holds the sample id rather than a class label.'''
    contents = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            content, label = line.split('\t')
            content = re.sub(r"https?://\S+", "", content)
            content = re.sub(r"what's", "what is", content)
            content = re.sub(r"Won't", "will not", content)
            content = re.sub(r"can't", "can not", content)
            content = re.sub(r"\'s", " ", content)
            content = re.sub(r"\'ve", " have", content)
            content = re.sub(r"n't", " not", content)
            content = re.sub(r"i'm", "i am", content)
            content = re.sub(r"\'re", " are", content)
            content = re.sub(r"\'d", " would", content)
            content = re.sub(r"\'ll", " will", content)
            content = re.sub(r"e - mail", "email", content)
            content = re.sub(r"\d+ ", "NUM ", content)
            content = re.sub(r"<br />", '', content)
            content = re.sub(r'[\u0000-\u0019\u0021-\u0040\u007b-\uffff]', '', content)
            token = config.tokenizer.tokenize(content)
            token = [CLS] + token
            seq_len = len(token)
            mask = []
            token_ids = config.tokenizer.convert_tokens_to_ids(token)
            pad_size = config.pad_size
            if pad_size:
                if len(token) < pad_size:
                    mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                    token_ids = token_ids + [0] * (pad_size - len(token))
                else:
                    mask = [1] * pad_size
                    token_ids = token_ids[:pad_size]
                    seq_len = pad_size
            contents.append((token_ids, int(label), seq_len, mask))
    return contents


class DatasetIterator:
    '''Same iterator as in the training script, with the fixed residue check.'''
    def __init__(self, dataset, batch_size, device):
        self.batch_size = batch_size
        self.dataset = dataset
        self.n_batch = len(dataset) // batch_size
        self.device = device
        self.residue = len(dataset) % batch_size != 0
        self.index = 0

    def _to_tensor(self, datas):
        x = torch.LongTensor([item[0] for item in datas]).to(self.device)
        y = torch.LongTensor([item[1] for item in datas]).to(self.device)
        seq_len = torch.LongTensor([item[2] for item in datas]).to(self.device)
        mask = torch.LongTensor([item[3] for item in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batch:
            batches = self.dataset[self.index * self.batch_size:len(self.dataset)]
            self.index += 1
            return self._to_tensor(batches)
        elif self.index > self.n_batch:
            self.index = 0
            raise StopIteration
        elif self.index == self.n_batch and not self.residue:
            self.index = 0
            raise StopIteration
        else:
            batches = self.dataset[self.index * self.batch_size:(self.index + 1) * self.batch_size]
            self.index += 1
            return self._to_tensor(batches)

    def __iter__(self):
        return self

    def __len__(self):
        return self.n_batch + 1 if self.residue else self.n_batch


def build_iterator(dataset, config):
    return DatasetIterator(dataset, config.batch_size, config.device)


class Model(nn.Module):
    '''Same model as in the training script.'''
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x):
        context = x[0]
        mask = x[2]
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        output = self.fc(pooled)
        output = self.dropout(output)
        return output


config = Config('dataset')
test_data = load_dataset(config.test_path, config)
test_iter = build_iterator(test_data, config)
model = Model(config).to(config.device)


def test(config, model, test_iter):
    '''Predict labels for the test set; returns (id, prediction) pairs per batch.'''
    model.load_state_dict(torch.load(config.save_path))  # load the trained model
    model.eval()
    index_labels = []
    with torch.no_grad():
        for texts, index in test_iter:  # `index` is the sample-id column of test.txt
            outputs = model(texts)
            preds = torch.max(outputs.data, 1)[1]
            index = index.unsqueeze(1)
            preds = preds.unsqueeze(1)
            index_label = torch.cat([index, preds], dim=1).to('cpu').numpy()
            index_labels.append(index_label)
    return index_labels


index_labels = test(config, model, test_iter)
print('label:')
print(index_labels)

df = np.concatenate(index_labels, axis=0)
df = pd.DataFrame(df, columns=('id', 'target'))
df = df.sort_values(by='id')
df.to_csv('./test.csv', index=None)
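If your test file carries true labels rather than sample ids, you don't need the id/target CSV at all. A sketch of the alternative (my own suggestion) is to reuse evaluate() from the training script with test=True, assuming that function is defined or imported in test.py:

model.load_state_dict(torch.load(config.save_path))
acc, loss, report, confusion = evaluate(config, model, test_iter, test=True)
print('accuracy:', acc)
print(report)     # per-class precision/recall/F1
print(confusion)  # confusion matrix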

One thing to note here: the test-set size interacts with batch_size. In the original code the iterator computed

self.n_batch = len(dataset) // batch_size  # number of batches
......
if len(dataset) % self.n_batch != 0:

so a test set with fewer texts than batch_size made n_batch zero, and the modulo raised a ZeroDivisionError (the quick workaround was to set batch_size to 1). The listings above check len(dataset) % batch_size != 0 instead, which handles that case.

5. Summary

I have run the code above successfully. If your dataset is large (mine has 120,000 texts), a single training epoch should be enough; after one epoch my validation accuracy had already reached 90%. Adjust to your own situation.

Any other problems you hit are most likely caused by a mismatched deep-learning environment, which you will have to sort out on your own.

I am a deep-learning beginner myself and my coding still needs work; I hope the notes above can help someone else.
