This post walks through the pretraining process in pretrain_glm.py.

First comes the argument setup:
# Arguments.
args = get_args()
args.mem_length = args.mem_length if args.transformer_xl else 0
if args.load and not args.new_save_directory:
    args.experiment_name = os.path.basename(os.path.normpath(args.load))
else:
    args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
if args.save:
    args.save = os.path.join(args.save, args.experiment_name)

# Pytorch distributed.
# Initialize the distributed environment.
initialize_distributed(args)

# Random seeds for reproducibility.
set_random_seed(args.seed)
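For reference, set_random_seed in Megatron-derived code bases normally seeds every RNG that affects training. A minimal sketch of that behavior (the model-parallel CUDA seeding is specific to GLM/Megatron and is only noted in a comment):

import random

import numpy as np
import torch

def set_random_seed(seed):
    # Seed Python, NumPy and PyTorch so runs are reproducible.
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Model-parallel code additionally seeds each GPU's CUDA RNG,
        # e.g. mpu.model_parallel_cuda_manual_seed(seed) in Megatron-LM.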
Next, the tokenizer is prepared. Three encodings are available: BERT, GPT, and Chinese SentencePiece, implemented by BertWordPieceTokenizer, GPT2BPETokenizer, and ChineseSPTokenizer respectively.
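The selection is a straightforward dispatch on the tokenizer-type argument. A hedged sketch of that dispatch (the flag name, import path, and constructor arguments are assumptions, not GLM's verbatim code):

from data_utils.tokenization import (BertWordPieceTokenizer, GPT2BPETokenizer,
                                     ChineseSPTokenizer)

def prepare_tokenizer(args):
    # Hypothetical dispatch; GLM's real version lives in its tokenization module.
    if args.tokenizer_type == "BertWordPieceTokenizer":
        tokenizer = BertWordPieceTokenizer(tokenizer_model_type=args.tokenizer_model_type)
    elif args.tokenizer_type == "GPT2BPETokenizer":
        tokenizer = GPT2BPETokenizer()
    else:
        tokenizer = ChineseSPTokenizer()
    # The vocabulary size is read back off the tokenizer (attribute name assumed).
    args.vocab_size = tokenizer.num_tokens
    return tokenizer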
The training, validation, and test sets are then obtained through get_train_val_test_data():
def get_train_val_test_data(args, tokenizer):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
    (train_data, val_data, test_data) = (None, None, None)
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        data_config = configure_data()
        if args.block_lm:
            data_set_type = "Block"
        elif args.transformer_xl:
            data_set_type = "GPT-XL"
        else:
            data_set_type = "GPT2"
        data_config.set_defaults(data_set_type=data_set_type, transpose=False)
        train_data, val_data, test_data = data_config.apply(args, tokenizer)
        data_counts = torch.cuda.LongTensor([int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        data_counts = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(data_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = data_counts[0].item()
    args.do_valid = data_counts[1].item()
    args.do_test = data_counts[2].item()
    return train_data, val_data, test_data
Dataset construction here is driven by data_config.apply(), which wraps the make_loaders() function:
def make_loaders(args, tokenizer):
    """makes training/val/test"""
    if args.use_tfrecords:
        return make_tfrecord_loaders(args)
    world_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    if args.loader_scatter is not None:
        assert world_size % args.loader_scatter == 0
    # Batch size and sequence length.
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    seq_length = args.seq_length
    if seq_length < 0:
        seq_length = seq_length * world_size
    eval_seq_length = args.eval_seq_length
    if eval_seq_length is not None and eval_seq_length < 0:
        eval_seq_length = eval_seq_length * world_size
    split = get_split(args)
    # Dataset arguments.
    data_set_args = {
        'path': args.train_data,
        'seq_length': seq_length,
        'mem_length': args.mem_length,
        'delim': args.delim,
        'text_key': args.text_key,
        'label_key': 'label',
        'ds_type': args.data_set_type,
        'split': split,
        'loose': args.loose_json,
        'max_preds_per_seq': args.max_preds_per_seq,
        'presplit_sentences': args.presplit_sentences,
        'sample_one_document': args.sample_one_document,
        'filter_english': args.filter_english,
        'pre_tokenize': not args.no_pre_tokenize,
        'tokenizer': tokenizer,
        'save_splits': args.save_splits,
        'load_splits': args.load_splits,
        'save_test_data': args.save_test_data,
        'no_lazy_loader': args.no_lazy_loader,
        'loader_scatter': args.loader_scatter,
        'data_parallel_rank': mpu.get_data_parallel_rank(),
        "non_sentence_start": args.non_sentence_start,
        "half_lazy_loader": args.half_lazy_loader
    }
    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    # If optional eval args were set, replace their equivalent values in the arg dict.
    if eval_seq_length:
        eval_set_args['seq_length'] = eval_seq_length
    if args.eval_max_preds_per_seq:
        eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
    if args.eval_text_key is not None:
        eval_set_args['text_key'] = args.eval_text_key

    # Make dataset splits.
    train, valid, test = None, None, None
    if args.train_data is not None:
        # Build the training dataset.
        train = data_utils.make_dataset(**data_set_args)
        if data_utils.should_split(split):
            train, valid, test = train
        eval_set_args['tokenizer'] = tokenizer

    # Make validation and test datasets if necessary.
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = data_utils.make_dataset(**eval_set_args)
        eval_set_args['tokenizer'] = tokenizer
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = data_utils.make_dataset(**eval_set_args)

    # Wrap datasets with data loaders.
    use_block = args.block_lm or args.encoder_decoder
    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, tokenizer, batch_size, args.train_iters, args,
                                 shuffle=args.shuffle, block_collate=use_block)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, tokenizer, eval_batch_size, args.train_iters, args,
                                 shuffle=args.shuffle, block_collate=use_block)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, tokenizer, eval_batch_size, len(test) // eval_batch_size + 1,
                                args, shuffle=args.shuffle, block_collate=use_block)
        args.do_test = True
    else:
        args.do_test = False
    return train, valid, test
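get_split is defined elsewhere in configure_data.py; its job is to turn the --split argument (e.g. "949,50,1") into normalized train/valid/test fractions. A self-contained sketch of that parsing, given a hypothetical name so it is not mistaken for GLM's exact implementation:

def parse_split_ratios(split_str):
    # "949,50,1" -> [0.949, 0.05, 0.001]
    parts = [float(s) for s in split_str.split(',')]
    while len(parts) < 3:          # pad missing entries for valid/test
        parts.append(0.0)
    total = sum(parts)
    return [p / total for p in parts]

print(parse_split_ratios("949,50,1"))  # [0.949, 0.05, 0.001]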
make_loaders() in turn calls data_utils.make_dataset(), which dispatches to get_dataset(); this is where the [MASK] tokens and span handling are set up.
def get_dataset(name, tokenizer, pre_tokenize, data_parallel_rank, loader_scatter=None,
                no_lazy_loader=False, half_lazy_loader=False):
    """gets dataset object based on keyword args and file at `path`"""
    global_rank = torch.distributed.get_rank()
    if not supported_corpus(name):
        raise NotImplementedError('dataset %s is not supported' % name)
    dataset = corpora.NAMED_CORPORA[name]
    path = dataset.PATH
    if issubclass(dataset, corpora.PromptReader):
        if not (exists_lazy(path, data_type='prompt') and exists_lazy(path, data_type='text')) and not (
                loader_scatter is not None and
                exists_scatter(path, data_type='prompt', scatter_num=loader_scatter) and
                exists_scatter(path, data_type='text', scatter_num=loader_scatter)):
            # Create cached version of dataset for lazy loading if it doesn't exist.
            if global_rank == 0:
                print(f"Creating lazy loader for dataset {name}")
                prompt_writer = LazyWriter(path, data_type='prompt', is_array=pre_tokenize)
                text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
                writers = {'prompt': prompt_writer, 'text': text_writer}
                reader = dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
                reader.process()
                prompt_writer.close()
                text_writer.close()
            else:
                while not os.path.exists(LazyWriter.get_len_path(path, data_type='prompt')):
                    time.sleep(1)
        map_fn = (lambda x: x.tolist()) if pre_tokenize else None
        if loader_scatter is not None:
            if not (exists_scatter(path, data_type='prompt', scatter_num=loader_scatter) and
                    exists_scatter(path, data_type='text', scatter_num=loader_scatter)):
                if global_rank == 0:
                    print(f"Creating scatter loader for dataset {name}")
                    prompts = LazyLoader(path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                         is_array=pre_tokenize)
                    texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True,
                                       is_array=pre_tokenize)
                    indices = list(range(len(texts)))
                    random.shuffle(indices)
                    segment_length = (len(indices) - 1) // loader_scatter + 1
                    for i in range(loader_scatter):
                        scatter_path = get_scatter_path(path, scatter_rank=i)
                        prompt_writer = LazyWriter(scatter_path, data_type='prompt', is_array=pre_tokenize)
                        text_writer = LazyWriter(scatter_path, data_type='text', is_array=pre_tokenize)
                        for idx in indices[i * segment_length: (i + 1) * segment_length]:
                            prompt_writer.write(prompts[idx])
                            text_writer.write(texts[idx])
                        prompt_writer.close()
                        text_writer.close()
                else:
                    while not (exists_scatter(path, data_type='prompt', scatter_num=loader_scatter) and
                               exists_scatter(path, data_type='text', scatter_num=loader_scatter)):
                        time.sleep(1)
            scatter_path = get_scatter_path(path, scatter_rank=data_parallel_rank % loader_scatter)
            print(f"Rank {global_rank} is using scatter from {scatter_path}")
            prompts = LazyLoader(scatter_path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                 is_array=pre_tokenize, load_memory=no_lazy_loader,
                                 half_load=half_lazy_loader)
            texts = LazyLoader(scatter_path, data_type='text', map_fn=map_fn, mem_map=True,
                               is_array=pre_tokenize, load_memory=no_lazy_loader,
                               half_load=half_lazy_loader)
        else:
            prompts = LazyLoader(path, data_type='prompt', map_fn=map_fn, mem_map=True,
                                 is_array=pre_tokenize, load_memory=no_lazy_loader,
                                 half_load=half_lazy_loader)
            texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True,
                               is_array=pre_tokenize, load_memory=no_lazy_loader,
                               half_load=half_lazy_loader)
        text = corpora.PromptDataset(prompt_loader=prompts, text_loader=texts, tokenizer=tokenizer,
                                     to_tokenize=not pre_tokenize)
        if loader_scatter is None:
            if global_rank == 0:
                print(f"Create dataset {name} with {len(text)} documents")
                for i in range(10):
                    rand_id = i if i < 5 else random.randrange(len(text))
                    sample_tokens = text[rand_id]['tokens'][:1024]
                    print(sample_tokens)
                    print(tokenizer.DecodeIds(sample_tokens).encode('utf-8'))
        else:
            for scatter_id in range(loader_scatter):
                if data_parallel_rank % loader_scatter == scatter_id and \
                        data_parallel_rank // loader_scatter == 0:
                    print(f"Create dataset {name} at scatter {scatter_id} with {len(text)} documents")
                    for i in range(10):
                        sample_tokens = text[i]['tokens'][:1024]
                        print(sample_tokens)
                        print(tokenizer.DecodeIds(sample_tokens))
                torch.distributed.barrier()
        return text
    elif issubclass(dataset, corpora.KeyReader):
        if not (exists_lazy(path, data_type='text') and exists_lazy(path, data_type='mask')):
            # Create cached version of dataset for lazy loading if it doesn't exist.
            if global_rank == 0:
                text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
                mask_writer = LazyWriter(path, data_type='mask', is_array=True)
                writers = {'mask': mask_writer, 'text': text_writer}
                dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
                mask_writer.close()
                text_writer.close()
            else:
                while not os.path.exists(LazyWriter.get_len_path(path, data_type='mask')):
                    time.sleep(1)
        map_fn = (lambda x: x.tolist()) if pre_tokenize else None
        masks = LazyLoader(path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True)
        texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True, is_array=pre_tokenize)
        text = corpora.KeyDataset(mask_loader=masks, text_loader=texts, tokenizer=tokenizer,
                                  to_tokenize=not pre_tokenize)
        return text
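To make the [MASK]-and-span design concrete: GLM samples spans from the text, replaces each with a single [MASK] in part A, and appends the span contents after <|startofpiece|> sentinels in part B, where they are predicted autoregressively (part A attends bidirectionally; part B gets the 2D block_position_ids seen later in print_masked_text). Below is a toy, self-contained illustration of this corruption scheme; it is a simplification, not GLM's actual collator:

import random

def glm_blank_infilling(tokens, span_lengths, mask="[MASK]", sop="<|startofpiece|>"):
    """Toy GLM-style corruption: return (part_a, part_b) for one document."""
    part_a, part_b = [], []
    pos = 0
    for length in span_lengths:
        # Pick a span start after the previous span (assumes spans fit).
        start = random.randrange(pos, len(tokens) - length + 1)
        part_a.extend(tokens[pos:start])   # untouched context
        part_a.append(mask)                # the whole span becomes one [MASK]
        part_b.append(sop)                 # part B: sentinel + span tokens as targets
        part_b.extend(tokens[start:start + length])
        pos = start + length
    part_a.extend(tokens[pos:])
    return part_a, part_b

tokens = "the quick brown fox jumps over the lazy dog".split()
print(glm_blank_infilling(tokens, [2]))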
With this, the dataset is ready. In a multi-task setting, multi-task datasets are built as well (see the mixing sketch after the snippet below):

# Build the multi-task datasets.
if args.multi_task_ratio > 0.0:
multi_train_data, multi_val_data = build_multi_task_dataset(args, tokenizer)
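args.multi_task_ratio is consumed later in forward_step (shown further below): each step draws a uniform random number and pulls from the multi-task iterator with that probability. A minimal self-contained sketch of this mixing scheme:

import random

def mixed_batches(main_iter, multi_iter, multi_task_ratio, seed=0):
    # Yield batches from two iterators, choosing the multi-task stream
    # with probability multi_task_ratio, as forward_step does.
    rng = random.Random(seed)
    while True:
        if multi_iter is not None and rng.random() < multi_task_ratio:
            yield next(multi_iter), "multi-task"
        else:
            yield next(main_iter), "bert"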
Next, the script decides whether to load a pretrained checkpoint, and sets up logging and output saving:
# Load a pretrained checkpoint if requested.
if args.load is not None:
    with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1):
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
    args.iteration = 0
torch.distributed.barrier()
if args.switch_linear:
    lr_scheduler.switch_linear(args)

summary_writer = None
if torch.distributed.get_rank() == 0:
    print('Pretrain GPT2 model')
    args.log_dir = None
    # Save logs and outputs.
    if args.train_iters > 0:
        args.log_dir = get_log_dir(base=args.summary_dir, name=args.experiment_name)
        summary_writer = get_sample_writer(log_dir=args.log_dir, iteration=args.iteration)
    print_and_save_args(args, verbose=True, log_dir=args.log_dir)
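get_sample_writer is, in Megatron-derived code bases, typically a thin wrapper around TensorBoard's SummaryWriter; a minimal sketch under that assumption (the add_text call is illustrative):

from torch.utils.tensorboard import SummaryWriter

def get_sample_writer(log_dir, iteration=0):
    # Assumed behavior: open a TensorBoard writer, dropping any events
    # logged after the resume point of a restarted run.
    writer = SummaryWriter(log_dir=log_dir, purge_step=iteration)
    writer.add_text("resume-iteration", str(iteration), global_step=iteration)
    return writer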
If the loaded model has already been trained, the data loaders are resumed; the main difference from the initial construction above is how the starting iteration is handled:
# Resume data loader if necessary.
if args.resume_dataloader:
    print_rank_0("Resume dataloader")
    if train_data is not None:
        train_data.batch_sampler.start_iter = args.iteration % len(train_data)
    if val_data is not None:
        start_iter_val = (args.iteration // args.eval_interval) * args.eval_iters
        val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
    if multi_train_data is not None:
        multi_train_data.batch_sampler.start_iter = int(args.iteration * args.multi_task_ratio) % len(
            multi_train_data)
    if multi_val_data is not None:
        start_iter_val = (args.iteration // args.eval_interval) * args.eval_iters * args.multi_task_ratio
        multi_val_data.batch_sampler.start_iter = start_iter_val % len(multi_val_data)

train_data_iterator = iter(train_data) if train_data is not None else None
multi_train_iterator = iter(multi_train_data) if multi_train_data is not None else None
val_data_iterator = iter(val_data) if val_data is not None else None
multi_val_iterator = iter(multi_val_data) if multi_val_data is not None else None
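Setting batch_sampler.start_iter works because GLM's batch sampler skips the batches it had already served before the checkpoint. A toy resumable sampler showing the idea (names and structure are illustrative):

class ResumableBatchSampler:
    """Toy sampler: yields fixed-size index batches, skipping the first start_iter."""

    def __init__(self, indices, batch_size):
        self.indices = indices
        self.batch_size = batch_size
        self.start_iter = 0  # set this before iterating to resume mid-epoch

    def __len__(self):
        return len(self.indices) // self.batch_size

    def __iter__(self):
        for i in range(self.start_iter, len(self)):
            yield self.indices[i * self.batch_size:(i + 1) * self.batch_size]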
Training then runs through the train() function:
iteration, skipped = train(model, optimizer,
lr_scheduler,
(train_data_iterator, multi_train_iterator),
(val_data_iterator, multi_val_iterator),
timers, args, summary_writer=summary_writer)
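Conceptually, train() is the standard loop: pull a batch, run the forward pass, backpropagate, step the optimizer and scheduler, with periodic evaluation and checkpointing folded in. A heavily simplified schematic (the real loop goes through train_step() and the fp16/DeepSpeed wrappers):

def train_sketch(model, optimizer, lr_scheduler, forward_step, data_iterators, timers, args):
    # Schematic only: one forward/backward/step per iteration.
    mems = []
    for iteration in range(args.iteration, args.train_iters):
        args.iteration = iteration
        loss, mems, mode = forward_step(data_iterators, model, args, timers, mems)
        optimizer.zero_grad()
        loss.backward()       # real code routes this through the fp16 loss scaler
        optimizer.step()
        lr_scheduler.step()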
The per-iteration work happens in train_step(); its core is forward_step():
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""
    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    timers('batch generator').stop()

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(attention_mask) == 1 else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(tokenizer.DecodeIds(tokens[batch_id, last_index: i].tolist()).encode('utf-8'), "|",
                          tokenizer.DecodeIds(labels[batch_id, last_index: i].tolist()).encode('utf-8'),
                          position_ids_[batch_id, last_index: i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()).encode('utf-8'), "|",
                  tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()).encode('utf-8'),
                  position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    # The loss is a cross-entropy over the (model-parallel) vocabulary.
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()
    return loss, mems, mode
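The masked mean at the end is worth spelling out: only positions where loss_mask == 1 (the part-B targets) contribute, and their summed loss is divided by the number of such positions. A self-contained equivalent in plain PyTorch, with torch.nn.functional.cross_entropy standing in for the model-parallel version:

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(vocab, (batch, seq))
loss_mask = torch.tensor([[0, 0, 1, 1, 0],
                          [0, 1, 1, 1, 1]], dtype=torch.float)

# Per-token losses, shape (batch * seq,); reduction='none' keeps them unreduced.
losses = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction='none')
# Mean over masked positions only, exactly as forward_step computes it.
loss = (losses * loss_mask.view(-1)).sum() / loss_mask.sum()
print(loss)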
Finally, the model is periodically checkpointed and evaluated:
# Checkpointing
if args.save and args.save_interval and args.iteration % args.save_interval == 0:
save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
# Evaluation
if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(args.iteration)
evaluate_and_print_results(
prefix, val_data_iterator, model, args, timers, verbose=False, step=args.iteration,
summary_writer=summary_writer, forward_step_func=forward_step)