
BERT Source Code (PyTorch): A Detailed Walkthrough


model.py

A walkthrough of the BERT source code from the transformers library

# coding=utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME

logger = logging.getLogger(__name__)

# "uncased" means case-insensitive; "multilingual" means multi-language
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)
    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    print("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
def gelu(x):  # BERT's activation function
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
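
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# A quick sanity check that the exact erf-based gelu above agrees with the tanh
# approximation quoted in its docstring; `_demo_gelu_approx` is a hypothetical
# helper used only for illustration.
def _demo_gelu_approx():
    x = torch.linspace(-3.0, 3.0, steps=7)
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    return torch.max(torch.abs(gelu(x) - approx))  # small, on the order of 1e-3 or less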
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):  # isinstance() checks whether an object is of a given type, similar to type()
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
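
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# The two ways of building a BertConfig shown above, plus the JSON round-trip;
# `_demo_bert_config` is a hypothetical helper, the numbers are arbitrary.
def _demo_bert_config():
    config = BertConfig(vocab_size_or_config_json_file=21128, hidden_size=768,
                        num_hidden_layers=12, num_attention_heads=12)
    same = BertConfig.from_dict(json.loads(config.to_json_string()))
    return config.to_dict() == same.to_dict()  # True: the round-trip preserves every field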
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)               # mean over the last (hidden) dimension
            s = (x - u).pow(2).mean(-1, keepdim=True)  # variance over the last dimension
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
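
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# Assuming apex is NOT installed, so the pure-PyTorch fallback above is in use:
# it is numerically equivalent to torch.nn.LayerNorm (both use the biased variance
# and put eps inside the square root). `_demo_layer_norm` is a hypothetical helper.
def _demo_layer_norm(hidden_size=768):
    x = torch.randn(2, 4, hidden_size)
    ours = BertLayerNorm(hidden_size, eps=1e-12)   # weight=1, bias=0 by default
    ref = nn.LayerNorm(hidden_size, eps=1e-12)     # same defaults
    return torch.max(torch.abs(ours(x) - ref(x)))  # tiny, around float rounding error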
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)    # 21128, 768 for bert-base-chinese
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)  # 512, 768
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)        # 2, 768
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        '''
        BERT has no recurrence or convolution in its architecture, so for the model to exploit word order
        we have to inject information about each token's relative or absolute position in the sequence.
        The approach here: randomly initialize a position embedding for every position, add it to the token
        embedding before feeding the model, and learn it as a regular parameter during training.
        Absolute position embeddings are simple, but they impose no explicit constraint between positions;
        we can only hope the model learns those relations implicitly. The Transformer paper used fixed
        sinusoidal position encodings instead, and later work introduced relative position encodings.
        '''

    def forward(self, input_ids, token_type_ids=None):
        # print('input', input_ids.size()) -> torch.Size([128, 32]); the input is [batch_size, seq_len]
        # e.g. [[ 101,  860, 7741, ...,    0,    0,    0], ...
        #       [ 101, 8183, 2399, ...,    0,    0,    0]]  where 101 is the token id of [CLS]
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        # initialize the positions: [0, 1, ..., 31]
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        # [128, 32]: [[0, 1, 2, ..., 29, 30, 31], ...
        #             [0, 1, 2, ..., 29, 30, 31]]  (repeated 128 times)
        if token_type_ids is None:  # if token types are not needed, i.e. only a single sentence is fed in
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        # all four tensors have shape [128, 32, 768], i.e. (batch_size, sequence_length, hidden_size)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        # Why LayerNorm + dropout here instead of BatchNorm?
        # https://www.zhihu.com/question/395811291/answer/1260290120
        return embeddings
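
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# Shape check for the embedding layer with a toy vocabulary; `_demo_embeddings`
# is a hypothetical helper and the sizes are arbitrary, not the real vocab.
def _demo_embeddings():
    config = BertConfig(vocab_size_or_config_json_file=100)  # hidden_size defaults to 768
    emb = BertEmbeddings(config)
    input_ids = torch.randint(0, 100, (128, 32))  # [batch_size, seq_len]
    out = emb(input_ids)                          # token + position + segment, then LayerNorm + dropout
    return out.shape                              # torch.Size([128, 32, 768])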
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        # hidden_size must be a multiple of num_attention_heads. For bert-base,
        # each attention layer has 12 heads and hidden_size is 768, so each head
        # has attention_head_size = 768 / 12 = 64.
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # 12 * 64 = 768. all_head_size is used instead of hidden_size because of head pruning
        # (prune_heads), which is not implemented in this file.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Split hidden_size into multiple heads and transpose the middle two dimensions
        # so that the matrix multiplications can be done per head.
        # x is usually the layer input, e.g. the embedding output of shape [128, 32, 768].
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)  # [128, 32, 12, 64]
        # Note the idiom: x.size()[:-1] keeps the first two dimensions, i.e. [128, 32];
        # print(new_x_shape) gives torch.Size([128, 32, 12, 64]) - this is a shape, not a tensor.
        x = x.view(*new_x_shape)  # reshape x to [128, 32, 12, 64]
        # Why the *: see the docs for view: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        # attention_mask: [128, 1, 1, 32], i.e. BERT's input mask with two extra dimensions added.
        # For example, if y = [6, 3] has shape [2], then y.unsqueeze(0) has shape [1, 2] and value [[6, 3]].
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)  # work the shapes out by hand to follow the matmuls
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # (batch_size, num_attention_heads, sequence_length, attention_head_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # [128, 12, 32, 32] (batch_size, num_attention_heads, sequence_length, sequence_length)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # divide by sqrt(d_k)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask  # the attention_mask handles the padding
        # The raw input mask looks like [1, 1, 1, 1, 0, 0] with shape [128, 32], so you might expect a
        # multiplication, but the mask has already been transformed (see extended_attention_mask in BertModel):
        # print it and you get [[[ -0., -0., -0., ..., -10000., -10000., -10000.]]], ...
        #                      [[[ -0., -0., -0., ..., -10000., -10000., -10000.]]], ...
        # Positions that were 1 become 0, and positions that were 0 (padding) become a large negative number,
        # so adding the mask pushes the padded positions to a large negative score.
        # After the softmax those entries are close to 0, which achieves the padding mask.

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch_size, num_attention_heads, sequence_length, attention_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)  # [128, 32, 768]
        context_layer = context_layer.view(*new_context_layer_shape)  # concatenation of the multi-head outputs
        return context_layer
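
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# The shape flow of one self-attention call, with the additive mask already in
# its "0 / -10000" form; `_demo_self_attention` is a hypothetical helper.
def _demo_self_attention():
    config = BertConfig(vocab_size_or_config_json_file=100)  # hidden_size 768, 12 heads by default
    attn = BertSelfAttention(config)
    hidden_states = torch.randn(128, 32, 768)
    mask = torch.zeros(128, 1, 1, 32)  # 0 = attend; padded positions would hold -10000.0
    out = attn(hidden_states, mask)
    return out.shape                   # torch.Size([128, 32, 768])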
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # residual connection: LayerNorm(x + sublayer(x)), i.e. the Add & Norm step
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):  # the two BERT inputs
        # input_tensor.size(): [128, 32, 768]; attention_mask.size(): [128, 1, 1, 32]
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output
class BertIntermediate(nn.Module):
    """
    Linear layer + activation: the first half of the position-wise feed-forward
    network (FFN) from the paper, projecting hidden_size up to intermediate_size.
    """
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
class BertOutput(nn.Module):
    """Another linear layer + dropout + LayerNorm with a residual connection;
    it projects intermediate_size back down to hidden_size."""
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
        # the usual way to stack multiple encoder layers

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
class BertPooler(nn.Module):
    """
    This layer simply takes the vector of the sentence's first token, i.e. [CLS],
    and passes it through a dense layer + tanh. Pooling could also be done in other
    ways, e.g. average or max pooling over the tokens (see the sketch after this class).
    """
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
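
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# The mask-aware mean-pooling alternative mentioned in the BertPooler docstring,
# instead of taking only the [CLS] vector; `_mean_pool` is a hypothetical helper.
def _mean_pool(sequence_output, attention_mask):
    # sequence_output: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 1/0
    mask = attention_mask.unsqueeze(-1).float()   # [batch, seq_len, 1]
    summed = (sequence_output * mask).sum(dim=1)  # zero out padding, then sum over tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)      # number of real tokens per example
    return summed / counts                        # [batch, hidden]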
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states  # [batch_size, seq_length, vocab_size]: for every position, a score over the vocabulary


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(BertPreTrainedModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        state_dict = kwargs.get('state_dict', None)
        kwargs.pop('state_dict', None)
        cache_dir = kwargs.get('cache_dir', None)
        kwargs.pop('cache_dir', None)
        from_tf = kwargs.get('from_tf', False)
        kwargs.pop('from_tf', None)
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file) or from_tf:
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(config_file):
            # Backward compatibility with old naming format
            config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path, map_location='cpu')
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
            return load_tf_weights_in_bert(model, weights_path)
        # Load from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        start_prefix = ''
        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
            start_prefix = 'bert.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model
class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]  # last layer's hidden states for every token; its first position is the [CLS] vector
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
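
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# What the extended_attention_mask trick above actually produces for a padded
# batch; `_demo_extended_mask` is a hypothetical helper.
def _demo_extended_mask():
    attention_mask = torch.LongTensor([[1, 1, 1, 0, 0]])    # 1 = real token, 0 = padding
    ext = attention_mask.unsqueeze(1).unsqueeze(2).float()  # [1, 1, 1, 5], broadcastable over heads and query positions
    ext = (1.0 - ext) * -10000.0
    return ext  # tensor([[[[-0., -0., -0., -10000., -10000.]]]])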
class BertForPreTraining(BertPreTrainedModel):
    """BERT model with pre-training heads.
    This module comprises the BERT model followed by the two pre-training heads:
        - the masked language modeling head, and
        - the next sentence classification head.

    Params:
        config: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., vocab_size]
        `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.

    Outputs:
        if `masked_lm_labels` and `next_sentence_label` are not `None`:
            Outputs the total_loss which is the sum of the masked language modeling loss and the next
            sentence classification loss.
        if `masked_lm_labels` or `next_sentence_label` is `None`:
            Outputs a tuple comprising
            - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
            - the next sentence classification logits of shape [batch_size, 2].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForPreTraining(config)
    masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                                   output_all_encoded_layers=False)
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            return total_loss
        else:
            return prediction_scores, seq_relationship_score
class BertForMaskedLM(BertPreTrainedModel):
    """BERT model with the masked language modeling head.
    This module comprises the BERT model followed by the masked language modeling head.

    Params:
        config: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., vocab_size]

    Outputs:
        if `masked_lm_labels` is not `None`:
            Outputs the masked language modeling loss.
        if `masked_lm_labels` is `None`:
            Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForMaskedLM(config)
    masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                       output_all_encoded_layers=False)
        prediction_scores = self.cls(sequence_output)

        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            return masked_lm_loss
        else:
            return prediction_scores
class BertForNextSentencePrediction(BertPreTrainedModel):
    """BERT model with next sentence prediction head.
    This module comprises the BERT model followed by the next sentence classification head.

    Params:
        config: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.

    Outputs:
        if `next_sentence_label` is not `None`:
            Outputs the next sentence classification loss.
        if `next_sentence_label` is `None`:
            Outputs the next sentence classification logits of shape [batch_size, 2].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False)
        seq_relationship_score = self.cls(pooled_output)

        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            return next_sentence_loss
        else:
            return seq_relationship_score
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token.
            (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_labels].

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_labels = 2

    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
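
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# One fine-tuning step for the sequence classifier with a tiny, randomly
# initialized config (a real run would load weights via from_pretrained,
# e.g. 'bert-base-chinese'); `_demo_classification_step` is a hypothetical helper.
def _demo_classification_step():
    config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=48,
                        num_hidden_layers=2, num_attention_heads=4, intermediate_size=96)
    model = BertForSequenceClassification(config, num_labels=2)
    input_ids = torch.randint(0, 100, (4, 16))
    attention_mask = torch.ones(4, 16, dtype=torch.long)
    labels = torch.LongTensor([0, 1, 1, 0])
    loss = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss.backward()  # gradients flow through BERT and the classifier head
    return loss.item()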
class BertForMultipleChoice(BertPreTrainedModel):
    """BERT model for multiple choice tasks.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_choices`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
            and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices].

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_choices = 2

    model = BertForMultipleChoice(config, num_choices)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_choices):
        super(BertForMultipleChoice, self).__init__(config)
        self.num_choices = num_choices
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, self.num_choices)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            return loss
        else:
            return reshaped_logits
class BertForTokenClassification(BertPreTrainedModel):
    """BERT model for token-level classification.
    This module is composed of the BERT model with a linear layer on top of
    the full hidden state of the last layer.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, ..., num_labels].

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    num_labels = 2

    model = BertForTokenClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
class BertForQuestionAnswering(BertPreTrainedModel):
    """BERT model for Question Answering (span extraction).
    This module is composed of the BERT model with a linear layer on top of
    the sequence output that computes start_logits and end_logits

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
            Positions are clamped to the length of the sequence and position outside of the sequence are not taken
            into account for computing the loss.
        `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
            Positions are clamped to the length of the sequence and position outside of the sequence are not taken
            into account for computing the loss.

    Outputs:
        if `start_positions` and `end_positions` are not `None`:
            Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
        if `start_positions` or `end_positions` is `None`:
            Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
            position tokens of shape [batch_size, sequence_length].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForQuestionAnswering(config)
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.bert = BertModel(config)
        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            return total_loss
        else:
            return start_logits, end_logits
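
# --- Illustrative sketch (my addition, not part of the original modeling file) ---
# A naive way to turn start_logits / end_logits into a predicted answer span at
# inference time (the real run_squad.py is more elaborate, with n-best lists);
# `_decode_span` and its parameters are hypothetical.
def _decode_span(start_logits, end_logits, max_answer_length=30):
    # start_logits, end_logits: [seq_len] tensors for one example
    scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)  # [seq_len, seq_len]: score of each (start, end) pair
    keep = torch.triu(torch.ones_like(scores)) - torch.triu(torch.ones_like(scores), diagonal=max_answer_length)
    scores = scores.masked_fill(keep == 0, float('-inf'))         # require start <= end < start + max_answer_length
    best = torch.argmax(scores)                                   # flat index of the best pair
    return divmod(best.item(), scores.size(1))                    # (start_index, end_index)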
