当前位置:   article > 正文

书生·浦语大模型实战营:基于InternLM2-chat-7B微调一个Text-to-SQL领域模型

书生·浦语大模型实战营:基于InternLM2-chat-7B微调一个Text-to-SQL领域模型


在参加完书生·浦语大模型实战营后,我打算微调一个Text-to-SQL领域的垂直模型。选择上海人工智能实验室推出的InternLM2-chat-7B模型作为基座模型进行增量训练。

训练阶段

训练平台

AutoDL平台、RTX 4090(24G)、Ubuntu22.04、CUDA 12.1

配置环境

# 创建一个python 3.10的环境
conda create --name xtuner python=3.10 -y
# 激活环境
conda activate xtuner

# 拉取xtuner工具源码
mkdir xtuner && cd xtuner
git clone https://github.com/InternLM/xtuner.git
# 进入源码目录(和我起的文件名重复了)
cd xtuner
# 从源码安装 XTuner
pip install -e '.[all]'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

创建文件夹是放在/root/autodl-tmp/下,该路径是数据盘,之后,我们在/root/autodl-tmp/下新建一个nl2sql文件夹作为工作路径。

数据集

使用DB-GPT处理并在Hugging Face开源的数据集,经过筛除掉多轮对话数据以及整理格式后得到19,5297条数据。
DB-GPT-Hub:https://github.com/eosphoros-ai/DB-GPT-Hub
数据集:https://huggingface.co/datasets/Healthy13/Text2SQL
处理后的格式如下:

[
  {
    "question": "which states border arizona",
    "context": "CREATE TABLE mountain (mountain_name, mountain_altitude, state_name, country_name); CREATE TABLE city (city_name, state_name, population, country_name); CREATE TABLE road (road_name, state_name); CREATE TABLE border_info (state_name, border); CREATE TABLE river (river_name, length, traverse, country_name); CREATE TABLE state (state_name, capital, population, area, country_name, density); CREATE TABLE highlow (state_name, highest_point, highest_elevation, lowest_point, lowest_elevation); CREATE TABLE lake (lake_name, area, state_name, country_name)",
    "answer": "SELECT border FROM border_info WHERE state_name = 'arizona'"
  },
  ...
  {}
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

处理成这样的好处是可以直接使用xtuner config中提前设置好的模板中关于sql数据集的映射文件

from xtuner.utils import SYSTEM_TEMPLATE


def sql_map_fn(example):
    return {
        'conversation': [{
            'system': SYSTEM_TEMPLATE.sql,
            'input': '{context}\n{question}'.format(**example),
            'output': example['answer']
        }]
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

模型下载

python ./model_download.py

import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
model_dir = snapshot_download('Shanghai_AI_Laboratory/internlm2-chat-7b', cache_dir='/root/autodl-tmp/nl2sql')
  • 1
  • 2
  • 3
  • 4

我们需要将用来微调的基座模型下载到本地

微调

XTuner提供多个开箱即用的配置文件,可以通过下列命令查看:

# 列出所有内置配置
xtuner list-cfg
  • 1
  • 2

但是我发现给出的配置中没有找到关于internlm2_chat_7b_…关于sql的py文件,于是我选择针对internlm_chat_7b_qlora_sql_e3.py文件进行修改:-表示删除 + 表示增加

# Model
- pretrained_model_name_or_path = 'internlm/internlm2-7b'
+ pretrained_model_name_or_path = './Shanghai_AI_Laboratory/internlm2-chat-7b' //模型加载地址换成本地下载好的模型
use_varlen_attn = False

# Data
- data_path = 'b-mc2/sql-create-context'
+ data_path = './dataset/single_multi_text2sql_xtuner.json' //训练所需的数据集换成本地数据集
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

train_dataset = dict(
    type=process_hf_dataset,
    - dataset=dict(type=load_dataset, path=data_path),
    + dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=sql_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

修改好该配置文件后,通过xtuner train命令开始训练,并开启deepspeed加速

xtuner train ./internlm2_7b_qlora_sql_e3_copy.py --deepspeed deepseed_zero2
  • 1

将得到的PTH模型转换为HuggingFace模型,即:生成Adapter文件

mkdir hf
export MKL_SERVICE_FORCE_INTEL=1
xtuner convert pth_to_hf ./internlm2_7b_qlora_sql_e3_copy.py ./work_dirs/internlm2_7b_qlora_sql_e3_copy/epoch_3.pth ./hf
  • 1
  • 2
  • 3

将HuggingFace Adapter合并到基座模型

xtuner convert merge ./Shanghai_AI_Laboratory/internlm2-chat-7b ./hf ./merged --max-shard-size 2GB
  • 1

此时,/root/autodl-tmp/nl2sql/路径下文件目录如下:

├── Shanghai_AI_Laboratory
│   └── internlm2-chat-7b
│       ├── README.md
│       ├── config.json
│       ├── configuration.json
│       ├── configuration_internlm2.py
│       ├── generation_config.json
│       ├── modeling_internlm2.py
│       ├── pytorch_model-00001-of-00008.bin
│       ├── pytorch_model-00002-of-00008.bin
│       ├── pytorch_model-00003-of-00008.bin
│       ├── pytorch_model-00004-of-00008.bin
│       ├── pytorch_model-00005-of-00008.bin
│       ├── pytorch_model-00006-of-00008.bin
│       ├── pytorch_model-00007-of-00008.bin
│       ├── pytorch_model-00008-of-00008.bin
│       ├── pytorch_model.bin.index.json
│       ├── special_tokens_map.json
│       ├── tokenization_internlm2.py
│       ├── tokenization_internlm2_fast.py
│       ├── tokenizer.model
│       └── tokenizer_config.json
├── dataset
│   └── single_multi_text2sql_xtuner.json
├── hf
│   ├── README.md
│   ├── adapter_config.json
│   ├── adapter_model.bin
│   └── xtuner_config.py
├── internlm2_7b_qlora_sql_e3_copy.py
├── merged
│   ├── config.json
│   ├── configuration_internlm2.py
│   ├── generation_config.json
│   ├── modeling_internlm2.py
│   ├── pytorch_model-00001-of-00008.bin
│   ├── pytorch_model-00002-of-00008.bin
│   ├── pytorch_model-00003-of-00008.bin
│   ├── pytorch_model-00004-of-00008.bin
│   ├── pytorch_model-00005-of-00008.bin
│   ├── pytorch_model-00006-of-00008.bin
│   ├── pytorch_model-00007-of-00008.bin
│   ├── pytorch_model-00008-of-00008.bin
│   ├── pytorch_model.bin.index.json
│   ├── special_tokens_map.json
│   ├── tokenization_internlm2.py
│   ├── tokenization_internlm2_fast.py
│   ├── tokenizer.json
│   ├── tokenizer.model
│   └── tokenizer_config.json
├── model_download.py
└── work_dirs
    └── internlm2_7b_qlora_sql_e3_copy
        ├── 20240311_092740
        │   ├── 20240311_092740.log
        │   └── vis_data
        │       ├── 20240311_092740.json
        │       ├── config.py
        │       └── scalars.json
        ├── 20240311_093606
        │   ├── 20240311_093606.log
        │   └── vis_data
        │       ├── 20240311_093606.json
        │       ├── config.py
        │       └── scalars.json
        ├── 20240311_093944
        │   ├── 20240311_093944.log
        │   └── vis_data
        │       ├── 20240311_093944.json
        │       ├── config.py
        │       └── scalars.json
        ├── 20240311_094125
        │   ├── 20240311_094125.log
        │   └── vis_data
        │       ├── 20240311_094125.json
        │       ├── config.py
        │       └── scalars.json
        ├── epoch_1.pth
        │   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
        │   └── mp_rank_00_model_states.pt
        ├── epoch_2.pth
        │   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
        │   └── mp_rank_00_model_states.pt
        ├── epoch_3.pth
        │   ├── bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
        │   └── mp_rank_00_model_states.pt
        ├── internlm2_7b_qlora_sql_e3_copy.py
        ├── last_checkpoint
        └── zero_to_fp32.py
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89

使用Xtuner chat进行验证

# 加载 Adapter 模型对话(Float 16)
xtuner chat ./merged --prompt-template internlm2_chat

# 4 bit 量化加载
xtuner chat ./merged --bits 4 --prompt-template internlm2_chat
  • 1
  • 2
  • 3
  • 4
  • 5

至此,微调任务结束,点击前往模型下载地址

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/263040
推荐阅读
相关标签
  

闽ICP备14008679号