如何使用自有数据微调ChatGLM-6B

2023-11-22 17:20

本文主要是介绍如何使用自有数据微调ChatGLM-6B,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

构建自己的数据集

数据格式:问答对

官网例子

调整自有样本格式

数据集划分和数据量

微调方法

脚本参数如何修改

官网对于参数的解释:

如何防止过拟合和灾难遗忘

微调后效果评估

微调方法简介(理论)


构建自己的数据集

数据格式:问答对

官网例子

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{  "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",  "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"  
}

调整自有样本格式

结合您的任务场景,构建问题和答案,以提取关键字为例。

[
{  "content": "请提取下面句子的关键字:'''离岸人民币 (CNH) 兑美元北京时间04:59报7.1657元,较周二纽约尾盘上涨462点,盘中整体交投于7.2155-7.1617元区间。截至发稿,离岸人民币暂报7.1656,升值12基点。Wind数据显示,7月份以来美元指数持续下跌,12日更是大跌1.06%。与此同时,在岸、离岸人民币对美元汇率迎来反弹,在岸、离岸人民币双双收复7.2关。自上周四起,人民币对美元连续走高,截至12日收盘,离岸人民币对美元5个交易日累计涨幅达956个基点。'''",  "summary": "离岸人民币 反弹"  
},
{  "content": "请提取下面句子的关键字:'''连日来,人民币汇率强势回升,给市场留下深刻“记忆”。昨日晚间,在岸、离岸人民币汇率双双收复7.17关口,其中,离岸人民币汇率日内大涨近400点。而近一周以来,在岸、离岸人民币汇率强势回升近千点,升幅均超过1%。关于人民币汇率,中国人民银行行长易纲近期在《经济研究》发表《货币政策的自主性、有效性与经济金融稳定》一文。易纲在文中指出,“近年来人民币汇率弹性显著增强,提高了利率调控的自主性,促进了宏观经济稳定,经济基本面稳定又对汇率稳定形成支撑,外汇市场运行更有韧性,利率和汇率之间形成良性互动。”'''",  "summary": "人民币汇率 回升 强势 弹性 韧性 支撑"  
},
]

数据集划分和数据量

样本总量至少几百个,根据具体task调整,太少可能效果不好。

除了准备训练集,您还要准备一个验证集,测试与验证的比例可参考8:2、9:1等。如果您有测试集更好了,可以评测下效果。

微调方法

选用官网ptuningv2微调方法,去学习提示向量。对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

脚本参数如何修改

# P-tuning v2
!PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python3 main.py \--do_train \--train_file AdvertiseGen/train.json \--validation_file AdvertiseGen/dev.json \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path /home/mw/input/ChatGLM6B6449 \--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_train_batch_size 4 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 4 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
# 重要参数注释
PRE_SEQ_LEN=128 #前缀长度,前缀长度不占整个提示的输入token
LR=2e-2 # 学习率
--train_file AdvertiseGen/train.json \ # 训练集样本路径
--validation_file AdvertiseGen/dev.json \ # 验证集样本路径
--prompt_column content \ # 样本集json文件中问题/提示的key
--response_column summary \ # 样本集json文件中答案的key
--max_source_length 64 \ # 输入的token最大长度
--max_target_length 64 \ # 输出的token最大长度
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--max_steps 3000 \
--quantization_bit 4 # 使用int4量化--model_name_or_path /home/mw/input/ChatGLM6B6449 \ # 原始预训练模型文件存放位置
--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ # 微调生成的模型检查点/训练后的模型参数存放位置;后续推理预测时,需要再次加载此目录下的最后一个文件夹

官网对于参数的解释:

P-tuning v2
PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

上文以 P-tuning v2 方法采取的参数 quantization_bit=4、per_device_train_batch_size=4、gradient_accumulation_steps=4

如何防止过拟合和灾难遗忘

训练的epoch(轮数)不要太多。具体多少是个玄学,如果训练样本有很少,建议先试试1轮,模型效果不收敛再尝试加大epoch数,如果过拟合了,请降低epoch数;如果样本很多,1w+,建议先试试0.5轮。

chatglm-6b的epoch数的计算方法:
per_device_train_batch_size*gradient_accumulation_steps*max_steps/样本总数

训练日志里会输出epoch的数据,大家可以仔细观察。

微调后效果评估

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 假如大约12G左右的预训练模型文件放在如下目录
model_path = "/home/mw/input/ChatGLM6B6449"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Fine-tuning 后的表现测试
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
# 此处使用你的微调后得到的模型检查点目录,
#例如上文我们设置检查点目录为/home/output/adgen-chatglm-6b-pt-128-2e-2/,还要指定最终的检查点目录
#因为前面脚本设置了每1000step生成一个临时检查点,请使用最终那个检查点目录,即max-steps对应的目录。
prefix_state_dict = torch.load(os.path.join("/home/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)#V100 机型上可以不进行量化
#print(f"Quantized to 4 bit")
#model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response
# 以下是微调前的效果评估,即没有加载微调后的检查点文件。
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response

以上如有问题,欢迎评论区回复和交流~

微调方法简介(理论)

p*tuning 论文综述https://arxiv.org/pdf/2107.13586.pdf
p-tuning v1 论文https://arxiv.org/pdf/2103.10385.pdf
p-tuningv2 论文https://arxiv.org/pdf/2110.07602.pdf
prefix-tuning 论文https://aclanthology.org/2021.acl-long.353.pdf
Prompt Tuning 论文https://arxiv.org/pdf/2104.08691.pdf​

 以上方法均属于软提示(连续提示),区别硬提示(提示工程,或者理解为用自然语言去提示,也叫离散提示)。

现在大模型相关的良心论文里都开始写“直观感受”来解释他们的微调方法了,真是方便大家理解。下面用简单的直观语言来解释一下:

软提示微调的目标:去自动学习一些参数/向量,来模拟人工提示工程;即让模型在嵌入式空间自己学习一个提示向量,加到原来的输入之前,再去激活大模型的“潜力”。

谷歌的prompt tuning,这个名字有争议,因为目前还没有明确的分类和命名方法。学习一个嵌入向量,拼到原有输入的向量表示之前,然后喂给后续的预训练语言模型。如下如所示。

P-tuning v2:在v1的基础上,每一层transformer都加上一个要学习的提示向量。 如下图。

prefix-tuning:全参微调(下图顶部)会更新所有LM参数(红色的Transformer框),并要求为每个任务存储一个完整的模型副本。前缀调整(下图底部),它冻结LM参数并仅优化前缀(红色前缀块)。因此,只需要存储每个任务的前缀,使前缀调整模块化和空间高效。注意,每个垂直块表示一个时间步长的转换器活动。它也是在每个transformer前添加前缀序列来完成目标的。

目前看调优方法并无明显优劣之分,具体下游任务具体分析,目前细微差别只体现在论文的“相近工作比较中”谷歌的prompt tuning吐槽前缀调整包括编码器和解码器网络上的前缀,而提示调整只需要编码器上的提示。P-tuning v2吐槽P-tuning v1针对对复杂任务效果不佳,因为v1只调了input层(但谷歌说的prompt tuning也是只调了input层,但是效果也很显著)。

结论P-tuning v1和谷歌的prompt tuning比较类似:只修改了input层(即给transformer的输入层嵌入了一个可学习的提示向量);prefix-tuning和P-tuning v2比较类似:修改了每一个transformer的input层。

专家的回复

1、前缀长度是否占输入token,不占。大家知道chatglm-6b的输入token是有长度限制的,最初是2048个token,ptuningv2这种加前缀的微调方法,并不会挤占输入token,加入前缀并不会导致可用的输入token长度变短。

2、一个token大概对应1.8个汉字。

这篇关于如何使用自有数据微调ChatGLM-6B的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/411697

相关文章

C++使用栈实现括号匹配的代码详解

《C++使用栈实现括号匹配的代码详解》在编程中,括号匹配是一个常见问题,尤其是在处理数学表达式、编译器解析等任务时,栈是一种非常适合处理此类问题的数据结构,能够精确地管理括号的匹配问题,本文将通过C+... 目录引言问题描述代码讲解代码解析栈的状态表示测试总结引言在编程中,括号匹配是一个常见问题,尤其是在

Java中String字符串使用避坑指南

《Java中String字符串使用避坑指南》Java中的String字符串是我们日常编程中用得最多的类之一,看似简单的String使用,却隐藏着不少“坑”,如果不注意,可能会导致性能问题、意外的错误容... 目录8个避坑点如下:1. 字符串的不可变性:每次修改都创建新对象2. 使用 == 比较字符串,陷阱满

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Linux使用nload监控网络流量的方法

《Linux使用nload监控网络流量的方法》Linux中的nload命令是一个用于实时监控网络流量的工具,它提供了传入和传出流量的可视化表示,帮助用户一目了然地了解网络活动,本文给大家介绍了Linu... 目录简介安装示例用法基础用法指定网络接口限制显示特定流量类型指定刷新率设置流量速率的显示单位监控多个

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

C++ Primer 多维数组的使用

《C++Primer多维数组的使用》本文主要介绍了多维数组在C++语言中的定义、初始化、下标引用以及使用范围for语句处理多维数组的方法,具有一定的参考价值,感兴趣的可以了解一下... 目录多维数组多维数组的初始化多维数组的下标引用使用范围for语句处理多维数组指针和多维数组多维数组严格来说,C++语言没

在 Spring Boot 中使用 @Autowired和 @Bean注解的示例详解

《在SpringBoot中使用@Autowired和@Bean注解的示例详解》本文通过一个示例演示了如何在SpringBoot中使用@Autowired和@Bean注解进行依赖注入和Bean... 目录在 Spring Boot 中使用 @Autowired 和 @Bean 注解示例背景1. 定义 Stud