赞
踩
在前两章中,我们介绍了如何准备数据集以及搭建基本的语言模型架构。现在,我们进入到模型的训练阶段(这一章很简单)
提供鲁迅作品数据集:数据集文件源URL
和MPTS数据集:数据集文件2源URL
(mpts数据集虽然是json格式,但可以不用预处理,直接将带有json语法的数据集拿去训练也行)
本篇博文将通过分析提供的代码片段来介绍如何训练一个语言模型。该代码包含了一些实用的功能,比如损失记录、模型检查点保存、实时绘图展示训练过程中的损失变化等。
确保你已经完成了以下准备工作:
matplotlib
和 tensorflow
)。训练循环是模型训练的核心部分。在我们的代码中,训练循环被组织在一个简单的 for
循环中,它遍历每个 epoch,并在每个 epoch 中遍历数据集的每一个 batch。
for epoch in range(ste, EPOCHS):
for ii, dd in enumerate(dataset):
# ...
在这个循环中,我们首先检查是否需要跳过某些 batch,然后获取当前 batch 的数据并预处理为模型所需的输入格式。接下来,我们调用模型的 train_on_batch
方法来进行一次前向传播和反向传播,更新模型权重以最小化损失函数。
input_sequences, target_sequences = preprocess_data_for_training(dd)#这个函数上一张定义了
loss = model.train_on_batch(input_sequences, target_sequences)#训练1个batch的数据
为了监控训练过程,我们记录每次训练迭代的损失值,并将其可视化显示出来。这有助于我们了解模型的学习情况。
vli.append(loss)
plt.plot(range(1, len(vli) + 1), vli)
plt.show(block=False)
plt.pause(0.05)
在训练过程中,我们通常会定期保存模型的权重,以便在训练中断或完成时能够恢复或评估模型。此外,我们还允许用户通过文件指示来控制模型的保存。
if ii % save_every == 0:
model.save_weights(ncheckpoint_prefix)
elif 'true' in s or '1' in s:
model.save_weights(ncheckpoint_prefix)
在训练过程中,我们还可以定期测试模型的生成能力,这对于验证模型的有效性非常有帮助。
if 'true' in sp or '1' in sp:
m, d = test.load(ckpt_dir=ncheckpoint_dir, model_type=mt)
ret = test.generate_texts_fast(m, d, yw, num_generate=16, ret_ori=False)
print('原文:', yw, '\n\n测试续写:', repr(ret))
在本篇博文中,我们介绍了如何设置训练循环来训练一个语言模型,包括监控训练进度、保存模型检查点、以及进行中间测试。这些步骤对于训练任何深度学习模型都是非常重要的,希望这篇博文能帮助你更好地理解和实践模型训练的过程。
接下来,你可以尝试运行这段代码并观察模型的训练过程。调整超参数如学习率、批次大小等,看看它们如何影响模型的表现。此外,还可以尝试使用不同的数据集或模型架构来进一步提高模型的性能(详见下几章)。
基础模型1-3的完整代码(代码了一下,导入库的步骤和文件路径可以自己加),调用train函数就可以直接训练了:
训练函数以及后面的调用模型推理的函数会使用配置变量dic,格式如下:
代码有问题可以私信联系我,临时调整的格式。我用的tf版本是2.10.1。
def train( mt=3, big_file=False,#是否采用大文件加载策略 #数据集 path_to_file = r'en_novel.txt', ntype_='_en',#保存为微调模型名称 #设置vocab版本 vtype_='_lx',#type_# fen=50,#数据量分几份 fwidx=0,#第几份 BATCH_SIZE = 64, loadmodel=False, pass_=-1, ste=0, ):''' 多出的参数不必理会,后面会用到 ''' global LR,param_data,p_ntype p_ntype=ntype_ if ntype_[0]!='_':ntype_='_'+ntype_ type_=ntype_ print('path_to_file',path_to_file) print('LR',LR) import os #dataset与vocab是配对的! if not os.path.exists(r'dataset/vocab'+vtype_+'.txt'): raise Exception("can't reading vocab from "+r'E:\小思框架\论文\ganskchat\vocab'+vtype_+'.txt') else: with open('dataset/vocab'+vtype_+'.txt','r',encoding='utf-8') as f: vocab=eval(f.read()) UNK=0 unkli=[] char2idx = { u:i for i, u in enumerate(vocab)} idx2char = np.array(vocab) print('{') for char,_ in zip(char2idx, range(20)): print(' {:4s}: {:3d},'.format(repr(char), char2idx[char])) print(' ...\n}') # 设定每个输入句子长度的最大值 seq_length = dic[mt][2] def split_input_target(chunk): input_text = chunk[:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。