赞
踩
预训练数据的预处理代码文件:
create_pretraining_data.py
功能:
在这个py文件中,主要功能是生成训练数据
具体的训练命令如下所示:
- python create_pretraining_data.py \
-
- --input_file=./sample_text.txt \
-
- --output_file=/tmp/tf_examples.tfrecord \
-
- --vocab_file=$BERT_BASE_DIR/vocab.txt \
-
- --do_lower_case=True \
-
- --max_seq_length=128 \
-
- --max_predictions_per_seq=20 \
-
- --masked_lm_prob=0.15 \
-
- --random_seed=12345 \
-
- --dupe_factor=5

在上面的命令行中,sample_text.txt是谷歌提供的一个小的训练样本,将这个小的训练样本经过一系列的处理,输出到tf_examples.tfrecord中
sample_text.txt:在这个文本中,空行前后代表不同的文章,每一行代表一句话
在函数的开始部分进行了相关参数的设置
- flags = tf.flags
-
- FLAGS = flags.FLAGS
-
- flags.DEFINE_string("input_file", None,
- "Input raw text file (or comma-separated list of files).")
-
- flags.DEFINE_string(
- "output_file", None,
- "Output TF example file (or comma-separated list of files).")
-
- flags.DEFINE_string("vocab_file", None,
- "The vocabulary file that the BERT model was trained on.")
-
- flags.DEFINE_bool(
- "do_lower_case", True,
- "Whether to lower case the input text. Should be True for uncased "
- "models and False for cased models.")
-
- flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
-
- flags.DEFINE_integer("max_predictions_per_seq", 20,
- "Maximum number of masked LM predictions per sequence.")
-
- flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
-
- flags.DEFINE_integer(
- "dupe_factor", 10,
- "Number of times to duplicate the input data (with different masks).")
-
- flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
-
- flags.DEFINE_float(
- "short_seq_prob", 0.1,
- "Probability of creating sequences which are shorter than the "
- "maximum length.")

在代码中相关参数的解释:
input_file:输入文件路径
output_file:输出文件路径
vocab_file:谷歌提供的词典,值为词典的路径
do_lower_case:当值为True时,则忽略大小写
max_seq_length:每一条训练数据(两句话)相加后的最大长度限制
max_predictions_per_seq:每一条训练数据mask的最大数量
random_seed:一个随机种子
dupe_factor:对文档多次重复随机产生训练集,随机的次数
masked_lm_prob:一条训练数据产生mask的概率,即每条训练数据随机产生max_predictions_per_seq×masked_lm_prob数量的mask
short_seq_prob:为了缩小预训练和微调过程的差距,以此概率产生小于max_seq_length的训练数据
首先获取输入文本,对输入文本创建训练实例,再进行输出,创建实例的函数是create_training_instances
在main()函数中,有一个FullTokenizer类,这个类的主要作用是将词转换成对应的id,参照的是字典vocab_file,但对一些特殊的词需要进行最大长度的拆分,如johanson,这个单词在字典中是没有的,但是johan和##son在字典中,则将johanson拆分成两个词,即johan和##son
- def main(_):
- tf.logging.set_verbosity(tf.logging.INFO)
-
- tokenizer = tokenization.FullTokenizer(
- vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
- # 创建tokenizer,很多人也许会困惑这个啥,这是Google AI Language Team写的一个字符处理的工具,按照代码里的使用就行
-
- input_files = []
- for input_pattern in FLAGS.input_file.split(","):
- input_files.extend(tf.gfile.Glob(input_pattern)) # #获得输入文件列表
- # tf.gfile.Glob()查找匹配pattern的文件并以列表的形式返回,filename可以是一个具体的文件名,也可以是包含通配符的正则表达式
-
- tf.logging.info("*** Reading from input files ***")
- for input_file in input_files:
- tf.logging.info(" %s", input_file)
-
- rng = random.Random(FLAGS.random_seed)
- instances = create_training_instances(
- input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
- FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
- rng)
-
- output_files = FLAGS.output_file.split(",")
- tf.logging.info("*** Writing to output files ***")
- for output_file in output_files:
- tf.logging.info(" %s", output_file)
-
- write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
- FLAGS.max_predictions_per_seq, output_files) # 输出

在main()函数中主要包括创建实例的create_training_instances()函数,以及输出函数write_instance_to_example_files(),下文会一一进行介绍
在这个函数中,先将文章的每个句子加到二维列表中,再将列表传入create_instances_from_document()函数生成训练实例
返回值:instances 一个列表 里面包含每个样例的TrainingInstance类
- def create_training_instances(input_files, tokenizer, max_seq_length,
- dupe_factor, short_seq_prob, masked_lm_prob,
- max_predictions_per_seq, rng):
- """Create `TrainingInstance`s from raw text."""
- all_documents = [[]]
-
- for input_file in input_files:
- with tf.gfile.GFile(input_file, "r") as reader:
- while True:
- line = tokenization.convert_to_unicode(reader.readline())
- if not line:
- break
- line = line.strip()
-
- # Empty lines are used as document delimiters
- if not line:
- all_documents.append([])
- tokens = tokenizer.tokenize(line) # 官方代码这里是这么处理每一行英文数据的,实际上可以简单理解为做了个分词操作吧
- if tokens:
- all_documents[-1].append(tokens) # 二维列表 [文章,句子]
-
- # Remove empty documents
- all_documents = [x for x in all_documents if x] # 删除空列表
- rng.shuffle(all_documents) # 随机排序
-
- vocab_words = list(tokenizer.vocab.keys())
- instances = []
- for _ in range(dupe_factor): # 对于一份数据,可以每次将masked 设定的位置都不一样,也就是可以做个数据扩充,代码中的dupe_factor就是将数据重复多次进行处理
- for document_index in range(len(all_documents)):
- instances.extend(
- create_instances_from_document(
- all_documents, document_index, max_seq_length, short_seq_prob,
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
-
- rng.shuffle(instances)
- return instances

(1)读取文本,按行分词处理后存储到all_documents中,里面存储的格式为[doc0,doc1,doc2,doc3,...],里面的每一个doc存储的是一个list,如doc1=[line0,line1,line2,lin3,...],同样的,每一个line里存储的也是一个list,如line1=[token0,token1,token2,token3,...],token表示的是一个个的词,之后对文章做shuffle处理
- all_documents = [[]]
-
- for input_file in input_files:
- with tf.gfile.GFile(input_file, "r") as reader:
- while True:
- line = tokenization.convert_to_unicode(reader.readline())
- if not line:
- break
- line = line.strip()
-
- # Empty lines are used as document delimiters
- if not line:
- all_documents.append([])
- tokens = tokenizer.tokenize(line) # 官方代码这里是这么处理每一行英文数据的,实际上可以简单理解为做了个分词操作吧
- if tokens:
- all_documents[-1].append(tokens) # 二维列表 [文章,句子]
-
- # Remove empty documents
- all_documents = [x for x in all_documents if x] # 删除空列表
- rng.shuffle(all_documents) # 随机排序

(2)重复dupe_factor=10次,每篇文章生成样本,[CLS+A+SEP+B+SEP]作为一条样本
- vocab_words = list(tokenizer.vocab.keys())
- instances = []
- for _ in range(dupe_factor): # 对于一份数据,可以每次将masked 设定的位置都不一样,也就是可以做个数据扩充,代码中的dupe_factor就是将数据重复多次进行处理
- for document_index in range(len(all_documents)):
- instances.extend(
- create_instances_from_document(
- all_documents, document_index, max_seq_length, short_seq_prob,
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
-
- rng.shuffle(instances)
在这个函数中,生成训练数据的具体过程,对每条数据生成TrainingInstance,这里的每条数据其实包含两个句子的信息,TrainingInstance包括tokens:词
segement_ids:句子编码,第一句为0,第二句为1
is_random_next:第二句是随机查找,还是未第一句的下文
masked_lm_positions:tokens中被mask的位置
masked_lm_labels:tokens中被mask的原来的词
返回值:instances
create_instances_from_document()函数对每篇文章都生成一个训练样本实例
从第一条句子循环到最后一条句子ii,收集segment到current_chunk列表中,当收集到的总句子长度>=单条样本最长值时,构造A+B
if i == len(document) - 1 or current_length >= target_seq_length:
随机截取 current_chunk的某个位置a_end,[0, a_end]作为子句A=token_a。
B句随机概率选择是Next or Not next,如果是next,则current_chunk的剩余[a_end, :]作为子句B=token_b。如果Not next,则随机挑一篇文章,选择某个长度的子句作为B=token_b。
- num_unused_segments = len(current_chunk) - a_end
- i -= num_unused_segments
两个句子加和长度超过最大长度怎么办?使用truncate_seq_pair在A和B中随机选择一个,随机丢掉首/尾的词,每次丢一个token,直到加和长度<=最大长度。
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
之后根据token_a和token_b生成tokens和segment_ids
tokens = [CLS, A_0, A_1, A_2, SEP, B_0, B_1, B_2, SEP]tokens=[CLS,A0,A1,A2,SEP,B0,B1,B2,SEP]
segment\_ids =[0_a, 0_a, 0_a, 0_a, 0_a, 1_b, 1_b, 1_b, 1_b]segment_ids=[0a,0a,0a,0a,0a,1b,1b,1b,1b]
再之后,根据tokens生成遮挡之后的tokens、遮挡位置masked_lm_positions、遮挡位置的真实词masked_lm_labels。
- (tokens, masked_lm_positions,
- masked_lm_labels) = create_masked_lm_predictions(
- tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
15%采样遮挡,对遮挡的处理情况如下:
a) 80%的概率,遮挡词被替换为[mask]。\longrightarrow⟶别人看不到我。
b) 10%的概率,遮挡词被替换为随机词。\longrightarrow⟶别人看走眼我。
c) 10%的概率,遮挡词被替换为原来词。\longrightarrow⟶别人能看到我。
- masked_token = None
- # 80% of the time, replace with [MASK]
- if rng.random() < 0.8:
- masked_token = "[MASK]"
- else:
- # 10% of the time, keep original
- if rng.random() < 0.5:
- masked_token = tokens[index]
- # 10% of the time, replace with random word
- else:
- masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
输入和返回结果举例:
input tokens ="The man went to the store . He bought a gallon of milk "
ouput tokens ="The man went to the [mask] . He [mask] a gallon of milk"
output masked_lm_positions = [5, 8, 10, 12]
output masked_lm_labels = [store, bought, gallon, ice]
位置#5,#8被遮挡,#10被替换为原token,#12被替换为随机词。注意CLS和SEP不会被遮挡。
然后保存成TrainingInstance类,同时保留了is_next标记.
- instance = TrainingInstance(
- tokens=tokens,
- segment_ids=segment_ids,
- is_random_next=is_random_next,
- masked_lm_positions=masked_lm_positions,
- masked_lm_labels=masked_lm_labels)
tokenization.FullTokenizer类用来处理分词,标点符号,unknown词,Unicode转换等操作。注意:中文只有单个字的切分,没有词。
存储为TF-Record
输入sentence变量的处理
- input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) ## ID化 ##
- input_mask = [1] * len(input_ids)
- segment_ids = segment_ids
- padding 0 --> max_seq_length
1. 对iput_ids 补0到句子最大长度
2. 对input_mask 补0到句子最大长度
3. 对segment_ids 补0到句子最大长度
注意:input_mask是样本中有效词句的标识,后面需要用作作attention视野的约束。
遮挡变量的处理
- masked_lm_positions = list(instance.masked_lm_positions)
- masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
- masked_lm_weights = [1.0] * len(masked_lm_ids)
- ## padding 0 --> max_seq_length
注意:masked_lm_ids是有mask的词对应的ID;masked_lm_positions是有mask的词对应的句子中位置。
next_sentense 处理
next_sentence_label = 1 if instance.is_random_next else 0
save format 处理
- features = collections.OrderedDict()
- features["input_ids"] = create_int_feature(input_ids)
- features["input_mask"] = create_int_feature(input_mask)
- features["segment_ids"] = create_int_feature(segment_ids)
- features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
- features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
- features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
- features["next_sentence_labels"] = create_int_feature([next_sentence_label])
-
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
读取使用dataset。
- input_ids = features["input_ids"]
- input_mask = features["input_mask"]
- segment_ids = features["segment_ids"]
- masked_lm_positions = features["masked_lm_positions"]
- masked_lm_ids = features["masked_lm_ids"]
- masked_lm_weights = features["masked_lm_weights"]
- next_sentence_labels = features["next_sentence_labels"]
下面的几个网址是参考的网址
http://www.manongjc.com/article/30232.html
https://blog.csdn.net/weixin_39470744/article/details/84619903
后续会继续进行更新
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。