当前位置:   article > 正文

使用 QLoRA 在 Google Colab 中微调 Mistral 7b(完整指南)_mistral7b微调

mistral7b微调

使用 QLoRA 在 Google Colab 中微调 Mistral 7b(完整指南)

在本文中,我们将在一个名为 Enlighten 的游戏的整个代码库上微调 Mistral 7b,所有这些都在 Google Colab(或 Kaggle)中免费提供合成数据。在我们的基准测试中,由此产生的模型将优于 Openai 的 GPT-4。

步骤如下:

  • 将代码库转换为基于合成对话的训练测试数据集
  • 使用 QLoRA 进行微调
  • 评估新模型
  • 评估基础模型 + GPT-4
  • (可选)将适配器与基本型号合并
  • (可选)量化模型并获取 GGUF 格式

本文中使用的所有代码和数据集都在资源部分提供,您只需要一个免费的拥抱脸和一个免费的 Google Colab 或 Kaggle 帐户。

介绍

考虑到你已经深入研究LLMs了,你已经知道什么是大型语言模型或它是如何训练的,所以请随时跳过这一部分,如果你只想学习编码部分,请跳到这一部分。

安德烈·卡帕西(Andre Karpathy)也制作了这个惊人的视频LLMs。本文的介绍主要基于它。

大型语言模型(LLM)是进行下一个标记预测的深度学习模型,它们将文本作为输入并预测句子中下一个最可能的单词。

下一个代币预测

然后,他们获取这个新标记,将其添加到上一个输入的末尾,并预测下一个输入。这种情况一直持续到预测的代币是 EOS(序列结束)代币,此时它会停止。

训练这些LLMs需要大量的数据和计算,例如,Llama 2 70b 在 10TB 文本和 6000 个 GPU 上训练了 12 天,花费了大约 200 万美元,最终,我们得到的是一个完成文档的模型。此步骤称为预训练。

骆驼 2 70b

为了让它像助手一样说话,模型稍后会使用基于对话的数据进行训练,在此步骤中,模型将学习像助手一样行事,并使用在相关步骤中学到的所有知识。这为我们提供了指令模型。

此外,为了增加更多知识或增强模型在某些领域或更多领域的能力,我们根据新数据对模型进行微调,与之前的步骤相比,这种微调需要的数据和计算要少得多,但在消费级硬件上仍然不可能。

PEFT(参数高效微调)解决了这个问题,并应用了巧妙的方法,即使在 Google Colab 的免费套餐上也可以进行微调。

我们将使用量化和 LoRA(Low-Rank Adaptation)来微调 Mistral 7b,指导并向其介绍新知识。

LoRA(低等级适配)

LoRA 冻结模型的参数 (W0), 将小的可训练适配器层 (ΔW = BA) 附加到模型上, 并且只训练适配器.这大大减少了可训练参数,并消耗了更少的 RAM。LoRA 中一个重要的超参数是 r,在本例中为 r=2。

因此,我们将使用这种方法并对我们的数据进行微调 Mistral。

准备数据

数据集可在此处找到,因此您不需要在本节中运行任何代码,但我建议您阅读它以了解数据集是如何制作的。

如前所述,我们将对 Enlighten 的代码库进行微调,首先我在每个类中编写了一些关于该类和所有方法的文档。下面是其中一个类作为示例。

using DG.Tweening;
using UnityEngine;
using UnityEngine.Events;

/*
 * Player.Interactables.AnimatedInteractable
 * InteractableObject.cs is the base(abstract) class for all interactable objects in the game. they all must inherit from it or one of its children
 * all interactable objects have a child of InteractableObject.cs class attached to them
 * each script that inherits from InteractableObject.cs has its own custom logic for when the player is focusing on it and when the player interacts with it
 * this class (AnimatedInteractable.cs) inherits from InteractableObject.cs and adds the functionality of playing an animation when the player interacts with the object
 * other scripts can subscribe to the onInteractAction event to add custom logic when the player interacts with the object
 * gameObjects with this script attached to them must have an animator component with a trigger parameter called "OnInteract" and an animation that plays when the trigger is called
 */

[RequireComponent(typeof(Animator))]
public class AnimatedInteractable : InteractableObject {
    private Animator _animator;

    [Tooltip("If true, the object will only be animated once then disabled.")] 
    [SerializeField] private bool isOneTimeAnimated;
    
    //cooldown between each interaction. If 0, there is no cooldown
    [SerializeField] private float cooldown;

    //the action to invoke when the player interacts with the object. Set in the inspector
    [SerializeField] private UnityEvent onInteractAction;
    
    [SerializeField] private AudioSource audioSource;


    private void Start() {
        _animator = GetComponent<Animator>();
    }

    //player is no longer focusing on the current interactable object. child classes can override this method to add custom logic
    protected override void OnObjectWentOutOfFocus() { }


    //Called by PlayerInteractableObjectsManager.cs when the player presses the interact button while focusing on the object. Plays the animation and invokes the onInteractAction
    public override void Interact() {
        //play the animation
        audioSource.Play();
        _animator.SetTrigger("OnInteract");

        //if the object is one time animated, disable the collider so the player can't interact with it again
        if (isOneTimeAnimated) GetComponent<Collider>().enabled = false;

        //if it has a cooldown, disable the collider for the duration of the cooldown
        else if (cooldown != 0) {
            GetComponent<Collider>().enabled = false;
            DOTween.Sequence().AppendInterval(cooldown).OnComplete(() => { GetComponent<Collider>().enabled = true; });
        }

        //invoke the onInteractAction
        onInteractAction?.Invoke();
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

正如你所看到的,我们的数据目前只是一堆 C# 类,但我们需要基于指令的数据,以问答对的形式出现(这些原始 C# 类可用于微调基本的非指令模型,但生成的模型也将是非指令的,只会完成一个文档,主要用于代码完成)

为了解决这个问题,我们将使用一个更大、更强大的模型来基于代码库合成生成我们的数据,我选择了新发布的 Google Gemini Pro 来完成这项任务,因为它对于这个用例既免费又强大。(Gpt-4 将是最好的模型,但 API 不是免费的)。

我们需要两个数据集,一个用于训练,一个用于测试,两者都将由 Gemini 合成生成。训练数据将是一个关于代码及其答案的问题。测试数据集将采用多项选择题的形式,这是一个问题,后跟 4 个选项和正确的一个。

对于每个 C# 类,我生成了 20 个用于训练的 Q/A,生成了 3 个用于测试的 Q/A。20 个培训问题包括 10 个 just-code 问题和 10 个一般问题。总共有大约 90 个 C# 类。

现在我们将把每个类都交给 Gemini 并要求它生成我们的数据,为此,我们需要我们自己的自定义系统消息。这些是我设计的系统消息。

10 个纯代码问题的系统消息(训练数据)

# you take a C# class from a Unity project with it's documentaion and create 10 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# all questions should ask for code and all answers should be C# code.
# questions should have context. 
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 10 objects, each object has 3 text fields: class name, Q(question), A(answer)
# any '\n' in the text fields MUST be '\\n' so that when reading it later on, we won't run into any issues

# example output with 2 question/answer
[
  {
    "class": "className.cs",
    "Q": "in 'className.cs' How does the beast check if the player is in sight?", 
    "A": " ```csharp\\nVector3 direction = (player.position - new Vector3(0, 0.5f, 0)) - beastTransform.position;\\nif (Physics.Raycast(beastTransform.position, direction, out hit, eyeSightDistance, layerMask)) {\\n  if (hit.collider.CompareTag("Player")) {\\n    return true;\\n  }\\n}\\n```"
  },
  {
    "class": "className.cs",
    "Q": "What is the code used to calculate the distance between the player and the beast in 'className.cs' ?",
    "A": " ```csharp\\nif (navMeshAgent.remainingDistance > distanceToPlayerThatBeastRuns)\\n    navMeshAgent.speed = Constants.BeastFastSpeed;\\nelse navMeshAgent.speed = normalSpeed;\\n```"
  }
]
# end of examples.

# this is the C# class:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

10 个一般问题的系统消息(训练数据)

# you take a C# class from a Unity project with it's documentaion and create 10 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# questions should have context. 
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 10 objects, each object has 3 text fields: class name, Q(question), A(answer)
# any '\n' in the text fields MUST be '\\n' so that when reading it later on, we won't run into any issues

# example output with 2 question/answer
[
  {
    "class": "className.cs",
    "Q": "What is the purpose of the className.cs class?", 
    "A": "The className.cs class is the main controller for the beast. It manages the state of the beast and the transitions between them.\\n it is implemented in singleton pattern"
  },
  {
    "class": "className.cs", 
    "Q": "in 'className.cs' What is the purpose of the _roamingState variable?",
    "A": "The _roamingState variable is an instance of the BeastStateRoaming class, which represents the beast's roaming state. It manages the behavior and transitions related to the roaming state, including moving between predefined roaming positions."
  }
]
# end of examples.

# this is the C# class:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

3 道多项选择题的系统信息(测试数据)

# you take a C# class from a Unity project with it's documentaion and create 3 question/answer pairs for it.
# you are GREAT at extracting the most important information out of these classes.
# DO NOT say 'the code' or 'this code' in your questions and do not refer to it at all.
# mention the class name in all questions.
# your answer MUST be a valid json format, a list of 3 objects, each object has 6 text fields: class name, Question, a,b,c,d,Answer

# example output with 2 question/answer
[
  {
    "class": "className.cs",
    "Question": "In className.cs what is the purpose of the PlayerManager class?", 
    "a": "To control player movement", 
    "b": "To manage some player behavior functionality", 
    "c": "To handle player combat actions", 
    "d": "To store references to key player components", 
    "Answer": "b"
  },
  {
    "class": "className.cs",
    "Question": "What does the FarthestPlaceFromPlayer() method do in className.cs?", 
    "a": "Finds the farthest destination from the player", 
    "b": "Teleports the player", 
    "c": "Returns a random destination", 
    "d": "Sets the player position", 
    "Answer": "a"
  }
]
# end of examples.

# this is the C# class:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

在这些系统消息中LLM,我们首先告诉它我们希望它做什么,然后告诉它规则。响应是有效的 JSON,这使我们以后更容易。我还使用了小样本提示技术,给出LLM了一个示例响应,以便其输出更符合我们的需要。

DataGenerator.ipynb 完成所有这些操作,从读取所有 C# 类到生成合成数据,并将其另存为 CSV 文件。我们不会全部介绍,因为它不是本文的主要重点,但这两个代码块基本上是它的核心。

如何调用 Gemini API

genai.configure(api_key=geminiApiKey)
model = genai.GenerativeModel('gemini-pro')

def get_raw_text_gemini(file_content,systemMessage):
    response = model.generate_content(systemMessage+"\n\n"+file_content
                                    ,generation_config=genai.types.GenerationConfig(max_output_tokens=4000))

    return(response.text)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

将响应转换为 LLM Pandas 数据帧

def make_df(text):
    data = json.loads(text)
    df = pd.DataFrame(data)
    df=df.map(lambda x: x.replace('\\n', '\n'))
    return df


raw_response=get_raw_text_gemini(file_content,test_systemMessage)

df=make_df(raw_response)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

我还在训练数据集中添加了一些非合成数据。首先,我写了一些关于 Enlighten 项目的一般信息,然后为每个类添加了以下内容:

问题:编写“ClassName”类

答:整个ClassName.cs代码

最后,训练数据大约有一百万个令牌,我们得到了两个 CSV 文件,TestData.csv 和 TrainData.csv

使用 LoRA 进行微调

编码开始,您可以在 Google Colab 中运行此代码(整个微调笔记本),但首先更改运行时类型并激活 T4 GPU(如果您使用的是 Kaggle,请激活 P100 GPU)。让我们从声明一些变量开始

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "Enlighten_Instruct"

test_path='/content/Enlighten-Instruct/Dataset/TestData.csv'
train_path='/content/Enlighten-Instruct/Dataset/TrainData.csv'
  • 1
  • 2
  • 3
  • 4
  • 5

然后我们安装一些包,克隆 git 存储库(仅用于数据集),并导入库

%%capture
!git clone 'https://github.com/ali7919/Enlighten-Instruct.git'
!pip install -U bitsandbytes
!pip install transformers==4.36.2
!pip install -U peft
!pip install -U accelerate
!pip install -U trl
!pip install datasets==2.16.0
!pip install sentencepiece
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pandas as pd
from datasets import Dataset
import re
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

要将最终结果上传到 Hugging Face,我们必须首先登录它,我们将使用“密钥”,首先从左侧工具栏中选择“密钥”选项卡

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/754478
推荐阅读
相关标签