本文主要是介绍如何使用自有数据微调ChatGLM-6B,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
目录
构建自己的数据集
数据格式:问答对
官网例子
调整自有样本格式
数据集划分和数据量
微调方法
脚本参数如何修改
官网对于参数的解释:
如何防止过拟合和灾难遗忘
微调后效果评估
微调方法简介(理论)
构建自己的数据集
数据格式:问答对
官网例子
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
{ "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳", "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
调整自有样本格式
结合您的任务场景,构建问题和答案,以提取关键字为例。
[
{ "content": "请提取下面句子的关键字:'''离岸人民币 (CNH) 兑美元北京时间04:59报7.1657元,较周二纽约尾盘上涨462点,盘中整体交投于7.2155-7.1617元区间。截至发稿,离岸人民币暂报7.1656,升值12基点。Wind数据显示,7月份以来美元指数持续下跌,12日更是大跌1.06%。与此同时,在岸、离岸人民币对美元汇率迎来反弹,在岸、离岸人民币双双收复7.2关。自上周四起,人民币对美元连续走高,截至12日收盘,离岸人民币对美元5个交易日累计涨幅达956个基点。'''", "summary": "离岸人民币 反弹"
},
{ "content": "请提取下面句子的关键字:'''连日来,人民币汇率强势回升,给市场留下深刻“记忆”。昨日晚间,在岸、离岸人民币汇率双双收复7.17关口,其中,离岸人民币汇率日内大涨近400点。而近一周以来,在岸、离岸人民币汇率强势回升近千点,升幅均超过1%。关于人民币汇率,中国人民银行行长易纲近期在《经济研究》发表《货币政策的自主性、有效性与经济金融稳定》一文。易纲在文中指出,“近年来人民币汇率弹性显著增强,提高了利率调控的自主性,促进了宏观经济稳定,经济基本面稳定又对汇率稳定形成支撑,外汇市场运行更有韧性,利率和汇率之间形成良性互动。”'''", "summary": "人民币汇率 回升 强势 弹性 韧性 支撑"
},
]
数据集划分和数据量
样本总量至少几百个,根据具体task调整,太少可能效果不好。
除了准备训练集,您还要准备一个验证集,测试与验证的比例可参考8:2、9:1等。如果您有测试集更好了,可以评测下效果。
微调方法
选用官网ptuningv2微调方法,去学习提示向量。对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
脚本参数如何修改
# P-tuning v2
!PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file AdvertiseGen/train.json \--validation_file AdvertiseGen/dev.json \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /home/mw/input/ChatGLM6B6449 \--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_train_batch_size 4 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 4 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
# 重要参数注释
PRE_SEQ_LEN=128 #前缀长度,前缀长度不占整个提示的输入token
LR=2e-2 # 学习率
--train_file AdvertiseGen/train.json \ # 训练集样本路径
--validation_file AdvertiseGen/dev.json \ # 验证集样本路径
--prompt_column content \ # 样本集json文件中问题/提示的key
--response_column summary \ # 样本集json文件中答案的key
--max_source_length 64 \ # 输入的token最大长度
--max_target_length 64 \ # 输出的token最大长度
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--max_steps 3000 \
--quantization_bit 4 # 使用int4量化--model_name_or_path /home/mw/input/ChatGLM6B6449 \ # 原始预训练模型文件存放位置
--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ # 微调生成的模型检查点/训练后的模型参数存放位置;后续推理预测时,需要再次加载此目录下的最后一个文件夹
官网对于参数的解释:
P-tuning v2
PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
上文以 P-tuning v2 方法采取的参数 quantization_bit=4、per_device_train_batch_size=4、gradient_accumulation_steps=4
如何防止过拟合和灾难遗忘
训练的epoch(轮数)不要太多。具体多少是个玄学,如果训练样本有很少,建议先试试1轮,模型效果不收敛再尝试加大epoch数,如果过拟合了,请降低epoch数;如果样本很多,1w+,建议先试试0.5轮。
chatglm-6b的epoch数的计算方法:
per_device_train_batch_size*gradient_accumulation_steps*max_steps/样本总数
训练日志里会输出epoch的数据,大家可以仔细观察。
微调后效果评估
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 假如大约12G左右的预训练模型文件放在如下目录
model_path = "/home/mw/input/ChatGLM6B6449"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Fine-tuning 后的表现测试
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
# 此处使用你的微调后得到的模型检查点目录,
#例如上文我们设置检查点目录为/home/output/adgen-chatglm-6b-pt-128-2e-2/,还要指定最终的检查点目录
#因为前面脚本设置了每1000step生成一个临时检查点,请使用最终那个检查点目录,即max-steps对应的目录。
prefix_state_dict = torch.load(os.path.join("/home/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)#V100 机型上可以不进行量化
#print(f"Quantized to 4 bit")
#model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response
# 以下是微调前的效果评估,即没有加载微调后的检查点文件。
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response
以上如有问题,欢迎评论区回复和交流~
微调方法简介(理论)
p*tuning 论文综述https://arxiv.org/pdf/2107.13586.pdf
p-tuning v1 论文https://arxiv.org/pdf/2103.10385.pdf
p-tuningv2 论文https://arxiv.org/pdf/2110.07602.pdf
prefix-tuning 论文https://aclanthology.org/2021.acl-long.353.pdf
Prompt Tuning 论文https://arxiv.org/pdf/2104.08691.pdf
以上方法均属于软提示(连续提示),区别硬提示(提示工程,或者理解为用自然语言去提示,也叫离散提示)。
现在大模型相关的良心论文里都开始写“直观感受”来解释他们的微调方法了,真是方便大家理解。下面用简单的直观语言来解释一下:
软提示微调的目标:去自动学习一些参数/向量,来模拟人工提示工程;即让模型在嵌入式空间自己学习一个提示向量,加到原来的输入之前,再去激活大模型的“潜力”。
谷歌的prompt tuning,这个名字有争议,因为目前还没有明确的分类和命名方法。学习一个嵌入向量,拼到原有输入的向量表示之前,然后喂给后续的预训练语言模型。如下如所示。
P-tuning v2:在v1的基础上,每一层transformer都加上一个要学习的提示向量。 如下图。
prefix-tuning:全参微调(下图顶部)会更新所有LM参数(红色的Transformer框),并要求为每个任务存储一个完整的模型副本。前缀调整(下图底部),它冻结LM参数并仅优化前缀(红色前缀块)。因此,只需要存储每个任务的前缀,使前缀调整模块化和空间高效。注意,每个垂直块表示一个时间步长的转换器活动。它也是在每个transformer前添加前缀序列来完成目标的。
目前看调优方法并无明显优劣之分,具体下游任务具体分析,目前细微差别只体现在论文的“相近工作比较中”:谷歌的prompt tuning吐槽前缀调整包括编码器和解码器网络上的前缀,而提示调整只需要编码器上的提示。P-tuning v2吐槽P-tuning v1针对对复杂任务效果不佳,因为v1只调了input层(但谷歌说的prompt tuning也是只调了input层,但是效果也很显著)。
结论P-tuning v1和谷歌的prompt tuning比较类似:只修改了input层(即给transformer的输入层嵌入了一个可学习的提示向量);prefix-tuning和P-tuning v2比较类似:修改了每一个transformer的input层。
专家的回复
1、前缀长度是否占输入token,不占。大家知道chatglm-6b的输入token是有长度限制的,最初是2048个token,ptuningv2这种加前缀的微调方法,并不会挤占输入token,加入前缀并不会导致可用的输入token长度变短。
2、一个token大概对应1.8个汉字。
这篇关于如何使用自有数据微调ChatGLM-6B的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!