赞
踩
基于DistilBERT的标题多分类任务
本项目旨在使用DistilBERT模型对给定的标题文本进行多分类任务。项目包括从数据处理、模型训练、模型评估到最终的API部署。该项目采用模块化设计,以便于理解和维护。
.
├── bert_data
│ ├── train.txt
│ ├── dev.txt
│ └── test.txt
├── saved_model
├── results
├── logs
├── data_processing.py
├── dataset.py
├── training.py
├── app.py
└── main.py
bert_data/:存放训练集、验证集和测试集的数据文件。
saved_model/:存放训练好的模型和tokenizer。
results/:存放训练结果。
logs/:存放训练日志。
data_processing.py:数据处理模块,负责读取和预处理数据。
dataset.py:数据集类模块,定义了用于训练和评估的数据集类。
training.py:模型训练模块,定义了训练和评估模型的过程。
app.py:模型部署模块,使用FastAPI创建API服务。
main.py:主脚本,运行整个流程,包括数据处理、模型训练和部署。
为了确保数据处理和模型训练的顺利进行,请按照以下规范准备数据集文件。每个文件包含的标题和标签分别使用制表符(\t
)分隔。以下是一个示例数据集的格式。
数据文件应为纯文本文件,扩展名为.txt
,文件内容的每一行应包含一个文本标题和一个对应的分类标签,用制表符分隔。数据文件不应包含表头。
探索神秘的海底世界 7
如何在家中制作美味披萨 2
全球气候变化的原因和影响 1
最新的智能手机评测 8
健康饮食:如何搭配均衡的膳食 5
最受欢迎的电影和电视剧推荐 3
了解宇宙的奥秘:天文学入门 0
如何种植和照顾多肉植物 9
时尚潮流:今年夏天的必备单品 6
如何有效管理个人财务 4
通过以上规范和示例数据文件创建方法,可以确保数据文件符合项目需求,并顺利进行数据处理和模型训练。
功能:读取数据文件并进行预处理。
load_data(file_path)
: 读取指定路径的数据文件,并返回一个包含文本和标签的数据框。tokenize_data(data, tokenizer, max_length=128)
: 使用BERT的tokenizer对数据进行tokenize处理。main()
: 加载数据、tokenize数据并返回处理后的数据。功能:定义数据集类,便于模型训练。
TextDataset
: 将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。功能:定义训练和评估模型的过程。
train_model()
: 加载数据和tokenizer,创建数据集,加载模型,设置训练参数,定义Trainer,训练和评估模型,保存训练好的模型和tokenizer。功能:使用FastAPI进行模型部署。
predict(item: Item)
: 接收POST请求的文本输入,使用训练好的模型进行预测并返回分类结果。功能:运行整个流程,包括数据处理、模型训练和部署。
main()
: 运行模型训练流程,并输出训练完成的提示。pip install pandas torch transformers fastapi uvicorn scikit-learn
确保bert_data
文件夹下包含train.txt
、dev.txt
和test.txt
文件,每个文件包含文本和标签,使用制表符分隔。
运行main.py
脚本,进行数据处理和模型训练:
python main.py
训练完成后,模型和tokenizer将保存在saved_model
文件夹中。
运行app.py
脚本,启动API服务:
uvicorn app:app --reload
服务启动后,可以通过POST请求访问预测接口,进行文本分类预测。
curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{"text": "你的文本"}'
返回示例:
{
"prediction": 3
}
nvidia-smi
监控显存使用,避免显存溢出。功能:读取数据文件并进行预处理。
# data_processing.py import pandas as pd from transformers import DistilBertTokenizer def load_data(file_path): data = pd.read_csv(file_path, delimiter='\t', header=None) data.columns = ['text', 'label'] return data def tokenize_data(data, tokenizer, max_length=128): encodings = tokenizer(list(data['text']), truncation=True, padding=True, max_length=max_length) return encodings def main(): # 加载Tokenizer tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-chinese') # 加载数据 train_data = load_data('./bert_data/train.txt') dev_data = load_data('./bert_data/dev.txt') test_data = load_data('./bert_data/test.txt') # Tokenize数据 train_encodings = tokenize_data(train_data, tokenizer) dev_encodings = tokenize_data(dev_data, tokenizer) test_encodings = tokenize_data(test_data, tokenizer) return train_encodings, dev_encodings, test_encodings, train_data['label'], dev_data['label'], test_data['label'] if __name__ == "__main__": main()
功能:定义数据集类,便于模型训练。
# dataset.py
import torch
class TextDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
功能:定义训练和评估模型的过程。
# training.py import torch from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments from dataset import TextDataset import data_processing def train_model(): # 加载数据和tokenizer train_encodings, dev_encodings, test_encodings, train_labels, dev_labels, test_labels = data_processing.main() # 创建数据集 train_dataset = TextDataset(train_encodings, train_labels) dev_dataset = TextDataset(dev_encodings, dev_labels) test_dataset = TextDataset(test_encodings, test_labels) # 加载DistilBERT模型 model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-chinese', num_labels=10) model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) # 设置训练参数 training_args = TrainingArguments( output_dir='./results', # 输出结果目录 num_train_epochs=3, # 训练轮数 per_device_train_batch_size=16, # 训练时每个设备的批量大小 per_device_eval_batch_size=64, # 验证时每个设备的批量大小 warmup_steps=500, # 训练步数 weight_decay=0.01, # 权重衰减 logging_dir='./logs', # 日志目录 fp16=True, # 启用混合精度训练 ) # 定义Trainer trainer = Trainer( model=model, # 预训练模型 args=training_args, # 训练参数 train_dataset=train_dataset, # 训练数据集 eval_dataset=dev_dataset # 验证数据集 ) # 训练模型 trainer.train() # 评估模型 eval_results = trainer.evaluate() print(eval_results) # 保存模型 trainer.save_model('./saved_model') tokenizer = trainer.tokenizer tokenizer.save_pretrained('./saved_model') if __name__ == "__main__": train_model()
功能:使用FastAPI进行模型部署。
# app.py from fastapi import FastAPI from pydantic import BaseModel from transformers import DistilBertTokenizer, DistilBertForSequenceClassification import torch app = FastAPI() # 加载模型和tokenizer model = DistilBertForSequenceClassification.from_pretrained('./saved_model') tokenizer = DistilBertTokenizer.from_pretrained('./saved_model') class Item(BaseModel): text: str @app.post("/predict") def predict(item: Item): inputs = tokenizer(item.text, return_tensors="pt", max_length=128, padding='max_length', truncation=True) outputs = model(**inputs) prediction = torch.argmax(outputs.logits, dim=1) return {"prediction": prediction.item()} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)
功能:运行整个流程,包括数据处理、模型训练和部署。
# main.py
import training
def main():
# 训练模型
training.train_model()
print("模型训练完成并保存。")
if __name__ == "__main__":
main()
# client.py import requests def predict(text): url = "http://localhost:8000/predict" payload = {"text": text} headers = {"Content-Type": "application/json"} response = requests.post(url, json=payload, headers=headers) if response.status_code == 200: prediction = response.json() return prediction else: print(f"Error: {response.status_code}") print(response.text) return None if __name__ == "__main__": text_to_predict = "探索神秘的海底世界" prediction = predict(text_to_predict) if prediction: print(f"Prediction: {prediction['prediction']}")
数据处理模块:
数据集类模块:
TextDataset
类,用于将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。模型训练模块:
Trainer
进行模型训练和评估,并保存训练好的模型。模型部署模块:
主脚本:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。