当前位置:   article > 正文

KBQA学习记录-NER的main函数_nerprocessor

nerprocessor

目录

一、main函数实现的内容

0.main()函数

1.CrfInputExample类

2.CrfInputFeatures类

3.NERprocessor类

4.类所需函数

①load_and_cache_example

②crf_convert_examples_to_features


一、main函数实现的内容

在main()函数中,主要是对样本的处理,使得我们能够得到能够输入模型训练的数据。需要一些辅助工具的类,提前定义。大概流程如下:

1.通过argparser添加参数

2.实例化NER processor类

3.实例化tokenizer = bertTokenizer(),构建训练数据时用

4.实例化BertCRF模型,训练用

5.获取训练数据

6.训练,所需函数使用“NER训练及验证”文章记录。

0.main()函数

该函数设置了参数,以及引导了整体流程。

  1. def main():
  2. parser = argparse.ArgumentParser()
  3. parser.add_argument("--data_dir", default=None, type=str, required=True,
  4. help="数据文件目录,应当有train.txt dev.txt")
  5. parser.add_argument("--vob_file", default=None, type=str, required=True,
  6. help="词表文件")
  7. parser.add_argument("--model_config", default=None, type=str, required=True,
  8. help="模型配置文件json文件")
  9. parser.add_argument("--output_dir", default=None, type=str, required=True,
  10. help="输出结果的文件")
  11. # Other parameters
  12. parser.add_argument("--pre_train_model", default=None, type=str, required=False,
  13. help="预训练的模型文件,参数矩阵。如果存在就加载")
  14. parser.add_argument("--max_seq_length", default=128, type=int,
  15. help="输入到bert的最大长度,通常不应该超过512")
  16. parser.add_argument("--do_train", action='store_true',
  17. help="是否进行训练")
  18. parser.add_argument("--train_batch_size", default=8, type=int,
  19. help="训练集的batch_size")
  20. parser.add_argument("--eval_batch_size", default=8, type=int,
  21. help="验证集的batch_size")
  22. parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
  23. help="梯度累计更新的步骤,用来弥补GPU过小的情况")
  24. parser.add_argument("--learning_rate", default=5e-5, type=float,
  25. help="学习率")
  26. parser.add_argument("--weight_decay", default=0.0, type=float,
  27. help="权重衰减")
  28. parser.add_argument("--adam_epsilon", default=1e-8, type=float,
  29. help="Epsilon for Adam optimizer.")
  30. parser.add_argument("--max_grad_norm", default=1.0, type=float,
  31. help="最大的梯度更新")
  32. parser.add_argument("--num_train_epochs", default=3.0, type=float,
  33. help="epoch 数目")
  34. parser.add_argument('--seed', type=int, default=42,
  35. help="random seed for initialization")
  36. parser.add_argument("--warmup_steps", default=0, type=int,
  37. help="让学习增加到1的步数,在warmup_steps后,再衰减到0")
  38. args = parser.parse_args()
  39. args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  40. logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
  41. datefmt='%m/%d/%Y %H:%M:%S',
  42. level=logging.INFO)
  43. # filename='./output/bert-crf-ner.log',
  44. processor = NerProcessor()
  45. # 得到tokenizer
  46. tokenizer_inputs = ()
  47. tokenizer_kwards = {'do_lower_case': False,
  48. 'max_len': args.max_seq_length,
  49. 'vocab_file': args.vob_file}
  50. tokenizer = BertTokenizer(*tokenizer_inputs,**tokenizer_kwards)
  51. print(len(processor.get_labels()))
  52. model = BertCrf(config_name= args.model_config,model_name=args.pre_train_model,num_tags = len(processor.get_labels()),batch_first=True)
  53. model = model.to(args.device)
  54. train_dataset = load_and_cache_example(args,tokenizer,processor,'train')
  55. eval_dataset = load_and_cache_example(args,tokenizer,processor,'dev')
  56. test_dataset = load_and_cache_example(args, tokenizer, processor, 'test')
  57. if args.do_train:
  58. trains(args,train_dataset,eval_dataset,model)

1.CrfInputExample类

这个类里面定义的是样本相关的内容,样本id,样本text,样本label,用于后续调用

  1. class CrfInputExample(object):
  2. def __init__(self, guid, text, label=None):
  3. self.guid = guid
  4. self.text = text
  5. self.label = label

2.CrfInputFeatures类

这里面定义的是样本特征的内容,也就是用于输入模型的内容。

  1. class CrfInputFeatures(object):
  2. def __init__(self, input_ids, attention_mask, token_type_ids, label):
  3. self.input_ids = input_ids
  4. self.attention_mask = attention_mask
  5. self.token_type_ids = token_type_ids
  6. self.label = label

上面这两个类定义好,都是为了后面方便调用,直接获取相关内容。

3.NERprocessor类

这个类用来创建训练、验证、测试样本

  1. class NerProcessor(DataProcessor):
  2. def get_train_examples(self,data_dir):
  3. return self._create_examples(
  4. os.path.join(data_dir,"train.txt"))
  5. def get_dev_examples(self, data_dir):
  6. return self._create_examples(
  7. os.path.join(data_dir, "dev.txt"))
  8. def get_test_examples(self, data_dir):
  9. return self._create_examples(
  10. os.path.join(data_dir, "test.txt"))
  11. def get_labels(self):
  12. return CRF_LABELS
  13. @classmethod
  14. def _create_examples(cls, path):
  15. lines = []
  16. max_len = 0
  17. with codecs.open(path, 'r', encoding='utf-8') as f:
  18. word_list = []
  19. label_list = []
  20. for line in f:
  21. tokens = line.strip().split(' ')
  22. if 2 == len(tokens):
  23. word = tokens[0]
  24. label = tokens[1]
  25. word_list.append(word)
  26. label_list.append(label)
  27. elif 1 == len(tokens) and '' == tokens[0]:
  28. if len(label_list) > max_len:
  29. max_len = len(label_list)
  30. lines.append((word_list,label_list))
  31. word_list = []
  32. label_list = []
  33. examples = []
  34. for i,(sentence,label) in enumerate(lines):
  35. examples.append(
  36. CrfInputExample(guid=i,text=" ".join(sentence),label=label)
  37. )
  38. return examples

4.类所需函数

①load_and_cache_example

通过如下函数获取,该函数实现的内容大致是:

如果已经有存好的特征文件,就导入,否则就自己创建特征

创建特征:通过crf_convert_examples_to_features()函数获取特征

处理特征:将特征挨个抽取出来,每个特征都保存成一个列表,并转化为tensor,最后通过TensorDataset给整合起来(torch.utils.data.TensorDataset)

  1. def load_and_cache_example(args,tokenizer,processor,data_type):
  2. type_list = ['train', 'dev', 'test']
  3. if data_type not in type_list:
  4. raise ValueError("data_type must be one of {}".format(" ".join(type_list)))
  5. cached_features_file = "cached_{}_{}".format(data_type, str(args.max_seq_length))
  6. cached_features_file = os.path.join(args.data_dir, cached_features_file)
  7. if os.path.exists(cached_features_file):
  8. features = torch.load(cached_features_file)
  9. else:
  10. label_list = processor.get_labels()
  11. if type_list[0] == data_type:
  12. examples = processor.get_train_examples(args.data_dir)
  13. elif type_list[1] == data_type:
  14. examples = processor.get_dev_examples(args.data_dir)
  15. elif type_list[2] == data_type:
  16. examples = processor.get_test_examples(args.data_dir)
  17. else:
  18. raise ValueError("UNKNOW ERROR")
  19. features = crf_convert_examples_to_features(examples=examples,tokenizer=tokenizer,max_length=args.max_seq_length,label_list=label_list)
  20. logger.info("Saving features into cached file %s", cached_features_file)
  21. torch.save(features, cached_features_file)
  22. all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
  23. all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
  24. all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
  25. all_label = torch.tensor([f.label for f in features], dtype=torch.long)
  26. dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label)
  27. return dataset

②crf_convert_examples_to_features

将输入的样本转为特征,共四个特征

input_id:将序号,文本样本输入tokenizer.encode_plus,会返回多个值,第一个就是我们需要的input_id

attention_mask:根据input_id的长度,创建的全1的列表

token_type_id:将序号,文本样本输入tokenizer.encode_plus,会返回多个值,第二个就是我们需要的token_type_id

label_id:根据输入的标签列表,转成id之后,另外加上bert所需要的分隔符[CLS]等,对应位置可以添加0,因为有mask,后面计算的时候会自动抹除。

  1. def crf_convert_examples_to_features(examples,tokenizer,
  2. max_length=512,
  3. label_list=None,
  4. pad_token=0,
  5. pad_token_segment_id = 0,
  6. mask_padding_with_zero = True):
  7. label_map = {label:i for i, label in enumerate(label_list)}
  8. features = []
  9. for (ex_index, example) in enumerate(examples):
  10. inputs = tokenizer.encode_plus(
  11. example.text,
  12. add_special_tokens=True,
  13. max_length=max_length,
  14. truncate_first_sequence=True # We're truncating the first sequence in priority if True
  15. )
  16. input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
  17. attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
  18. padding_length = max_length - len(input_ids)
  19. input_ids = input_ids + ([pad_token] * padding_length)
  20. attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
  21. token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
  22. # 第一个和第二个[0] 加的是[CLS]和[SEP]的位置, [0]*padding_length是[pad] ,把这些都暂时算作"O",后面用mask 来消除这些,不会影响
  23. labels_ids = [0] + [label_map[l] for l in example.label] + [0] + [0]*padding_length
  24. assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
  25. assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),max_length)
  26. assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids),max_length)
  27. assert len(labels_ids) == max_length, "Error with input length {} vs {}".format(len(labels_ids),max_length)
  28. if ex_index < 5:
  29. logger.info("*** Example ***")
  30. logger.info("guid: %s" % (example.guid))
  31. logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
  32. logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
  33. logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
  34. logger.info("label: %s " % " ".join([str(x) for x in labels_ids]))
  35. features.append(
  36. CrfInputFeatures(input_ids,attention_mask,token_type_ids,labels_ids)
  37. )
  38. return features
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/酷酷是懒虫/article/detail/948578
推荐阅读
相关标签
  

闽ICP备14008679号