使用 HuggingFace 中的 Trainer 进行 BERT 模型微调,太方便了!!!

2024-06-08 11:36

本文主要是介绍使用 HuggingFace 中的 Trainer 进行 BERT 模型微调,太方便了!!!,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:

  • 《大模型面试宝典》(2024版) 发布!
  • 圈粉无数!《PyTorch 实战宝典》火了!!!

以往,我们在使用HuggingFace在训练BERT模型时,代码写得比较复杂,涉及到数据处理、token编码、模型编码、模型训练等步骤,从事NLP领域的人都有这种切身感受。

事实上,HugggingFace中提供了datasets模块(数据处理)和Trainer函数,使得我们的模型训练较为方便。

本文将会介绍如何使用HuggingFace中的Trainer对BERT模型微调。

Trainer

Trainer是HuggingFace中的模型训练函数,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer 。

Trainer的传入参数如下:

model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None
args: TrainingArguments = None
data_collator: typing.Optional[DataCollator] = None
train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None
eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None
tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None
compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None
callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None
optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None)
preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )

参数解释:

  • model为预训练模型

  • args为TrainingArguments(训练参数)类

  • data_collator会将数据集中的元素组成一个batch,默认使用default_data_collator(),如果tokenizer没有提供,则使用DataCollatorWithPadding

  • train_dataset, eval_dataset为训练集,验证集

  • tokenizer为模型训练使用的tokenizer

  • model_init为模型初始化

  • compute_metrics为验证集的评估指标计算函数

  • callbacks为训练过程中的callback列表

  • optimizers为模型训练中的优化器

  • preprocess_logits_for_metrics为模型评估阶段前对logits的预处理

TrainingArguments为训练参数类,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments,传入参数非常多(transformers版本4.32.1中有98个参数!),我们在这里只介绍几个常见的:

output_dir: stroverwrite_output_dir: bool = False
evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no'
per_gpu_train_batch_size: typing.Optional[int] = None
per_gpu_eval_batch_size: typing.Optional[int] = None
learning_rate: float = 5e-05
num_train_epochs: float = 3.0
logging_dir: typing.Optional[str] = None
logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'
save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'save_steps: float = 500
report_to: typing.Optional[typing.List[str]] = None

参数解释:

  • output_dir为模型输出目录

  • evaluation_strategy为模型评估策略

    1. “no": 不做模型评估

    2. “steps”: 按训练步数(steps)进行评估,需指定步数

    3. “epoch”: 每个epoch训练完后进行评估

  • per_gpu_train_batch_size, per_gpu_eval_batch_size为每个GPU上训练集和测试集的batch size,也有CPU上的对应参数

  • learning_rate为学习率

  • logging_dir为日志输出目录

  • logging_strategy为日志输出策略,同样有no, steps, epoch三种,意义同上

  • save_strategy为模型保存策略,同样有no, steps, epoch三种,意义同上

  • report_to为模型训练、评估中的重要指标(如loss, accurace)输出之处,可选择azure_ml, clearml, codecarbon, comet_ml, dagshub, flyte, mlflow, neptune, tensorboard, wandb,使用all会输出到所有的地方,使用no则不会输出。

下面我们使用Trainer进行BERT模型微调,给出英语、中文数据集上文本分类的示例代码。

BERT 微调

使用datasets模块导入imdb数据集(英语影评数据集,常用于文本分类),加载预训练模型bert-base-cased的tokenizer。

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
raw_datasets = datasets.load_dataset('imdb')

查看数据集,有train(训练集)、test(测试集)、unsupervised(非监督)三部分,我们这里使用训练集和测试集,各自有25000个样本。

raw_datasets
DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 25000})test: Dataset({features: ['text', 'label'],num_rows: 25000})unsupervised: Dataset({features: ['text', 'label'],num_rows: 50000})
})

创建数据tokenize函数,对文本进行tokenize,最大长度设置为300,同时使用data_collector为DataCollatorWithPadding。

def tokenize_function(sample):return tokenizer(sample['text'], max_length=300, truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

加载分类模型,输出类别为2.

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

设置compute_metrics函数,在评估过程中输出accuracy, f1, precision, recall四个指标。设置训练参数TrainingArguments类,设置Trainer。

from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_supportdef compute_metrics(pred):labels = pred.label_idspreds = pred.predictions.argmax(-1)precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')acc = accuracy_score(labels, preds)return {'accuracy': acc,'f1': f1,'precision': precision,'recall': recall}training_args = TrainingArguments(output_dir='imdb_test_trainer', # 指定输出文件夹,没有会自动创建evaluation_strategy="epoch",per_device_train_batch_size=32,per_device_eval_batch_size=32,learning_rate=5e-5,num_train_epochs=3,warmup_ratio=0.2,logging_dir='./imdb_train_logs',logging_strategy="epoch",save_strategy="epoch",report_to="tensorboard") trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["test"],data_collator=data_collator,  # 在定义了tokenizer之后,其实这里的data_collator就不用再写了,会自动根据tokenizer创建tokenizer=tokenizer,compute_metrics=compute_metrics
)

开启模型训练。

trainer.train()

输出结果如下:

EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.3643000.2232230.9106000.9105090.9122760.910600
20.1648000.2044200.9239600.9239410.9243750.923960
30.0710000.2413500.9255200.9255100.9257590.925520

以上为英语数据集的文本分类模型微调。

中文数据集使用sougou-mini数据集(训练集4000个样本,测试集495个样本,共5个输出类别),预训练模型采用bert-base-chinese。代码基本与英语数据集差不多,只要修改 预训练模型,数据集加载 和 最大长度为128,输出类别。以下是不同的代码之处:

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)data_files = {"train": "./data/sougou/train.csv", "test": "./data/sougou/test.csv"}
raw_datasets = datasets.load_dataset("csv", data_files=data_files, delimiter=",")
...
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=5)
...

输出结果如下:

EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.8492000.1151890.9696970.9694490.9700730.969697
20.1069000.0939870.9737370.9737700.9753720.973737
30.0478000.0788610.9737370.9737400.9741170.973737

模型评估

在上述模型评估过程中,已经有了模型评估的各项指标。
本文也给出单独做模型评估的代码,方便后续对模型做量化时(后续介绍BERT模型的动态量化)获取量化前后模型推理的各项指标。
中文数据集文本分类模型评估代码如下:

import torch
from transformers import AutoModelForSequenceClassificationMAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)from transformers import AutoTokenizer, DataCollatorWithPaddingtokenizer = AutoTokenizer.from_pretrained(checkpoint)import pandas as pdtest_df = pd.read_csv("./data/sougou/test.csv")
test_df.head()
import numpy as np
import times_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():row_s_time = time.time()true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)# print(encoded_text)logits = model(**encoded_text)label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]pred_labels.append(label_id)if i % 100 == 0:print(i, (time.time() - row_s_time)*1000, label_id)print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])

0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time: 352.44047810332944

true_labels[:10]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

pred_labels[:10]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

from sklearn.metrics import classification_reportprint(classification_report(true_labels, pred_labels, digits=4))

输出结果如下:

              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9691    0.9495    0.9592        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9895    0.9495    0.9691        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495

总结

本文介绍了如何使用HuggingFace中的Trainer对BERT模型微调。可以看到,使用Trainer进行模型微调,代码较为简洁,且支持功能丰富,是理想的模型训练方式。

这篇关于使用 HuggingFace 中的 Trainer 进行 BERT 模型微调,太方便了!!!的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1042074

相关文章

鸿蒙中@State的原理使用详解(HarmonyOS 5)

《鸿蒙中@State的原理使用详解(HarmonyOS5)》@State是HarmonyOSArkTS框架中用于管理组件状态的核心装饰器,其核心作用是实现数据驱动UI的响应式编程模式,本文给大家介绍... 目录一、@State在鸿蒙中是做什么的?二、@Spythontate的基本原理1. 依赖关系的收集2.

Python基础语法中defaultdict的使用小结

《Python基础语法中defaultdict的使用小结》Python的defaultdict是collections模块中提供的一种特殊的字典类型,它与普通的字典(dict)有着相似的功能,本文主要... 目录示例1示例2python的defaultdict是collections模块中提供的一种特殊的字

C++ Sort函数使用场景分析

《C++Sort函数使用场景分析》sort函数是algorithm库下的一个函数,sort函数是不稳定的,即大小相同的元素在排序后相对顺序可能发生改变,如果某些场景需要保持相同元素间的相对顺序,可使... 目录C++ Sort函数详解一、sort函数调用的两种方式二、sort函数使用场景三、sort函数排序

Java String字符串的常用使用方法

《JavaString字符串的常用使用方法》String是JDK提供的一个类,是引用类型,并不是基本的数据类型,String用于字符串操作,在之前学习c语言的时候,对于一些字符串,会初始化字符数组表... 目录一、什么是String二、如何定义一个String1. 用双引号定义2. 通过构造函数定义三、St

SpringSecurity6.0 如何通过JWTtoken进行认证授权

《SpringSecurity6.0如何通过JWTtoken进行认证授权》:本文主要介绍SpringSecurity6.0通过JWTtoken进行认证授权的过程,本文给大家介绍的非常详细,感兴趣... 目录项目依赖认证UserDetailService生成JWT token权限控制小结之前写过一个文章,从S

Pydantic中Optional 和Union类型的使用

《Pydantic中Optional和Union类型的使用》本文主要介绍了Pydantic中Optional和Union类型的使用,这两者在处理可选字段和多类型字段时尤为重要,文中通过示例代码介绍的... 目录简介Optional 类型Union 类型Optional 和 Union 的组合总结简介Pyd

Vue3使用router,params传参为空问题

《Vue3使用router,params传参为空问题》:本文主要介绍Vue3使用router,params传参为空问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录vue3使用China编程router,params传参为空1.使用query方式传参2.使用 Histo

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

使用Python自建轻量级的HTTP调试工具

《使用Python自建轻量级的HTTP调试工具》这篇文章主要为大家详细介绍了如何使用Python自建一个轻量级的HTTP调试工具,文中的示例代码讲解详细,感兴趣的小伙伴可以参考一下... 目录一、为什么需要自建工具二、核心功能设计三、技术选型四、分步实现五、进阶优化技巧六、使用示例七、性能对比八、扩展方向建

使用Python实现一键隐藏屏幕并锁定输入

《使用Python实现一键隐藏屏幕并锁定输入》本文主要介绍了使用Python编写一个一键隐藏屏幕并锁定输入的黑科技程序,能够在指定热键触发后立即遮挡屏幕,并禁止一切键盘鼠标输入,这样就再也不用担心自己... 目录1. 概述2. 功能亮点3.代码实现4.使用方法5. 展示效果6. 代码优化与拓展7. 总结1.