【LLM】PISSA:一种高效的微调方法

2024-06-22 16:52

本文主要是介绍【LLM】PISSA:一种高效的微调方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

介绍PISSA前,先简单过一下LLMs微调经常采用的LoRA(Low-Rank Adaptation)微调的方法,LoRA 假设权重更新的过程中有一个较低的本征秩,对于预训练的权重参数矩阵 W 0 ∈ R d × k W_0 ∈ R^{d×k} W0Rd×k,( d d d 为上一层输出维度, k k k 为下一层输入维度),使用低秩分解来表示其更新:

在训练过程中, W 0 W_0 W0冻结不更新, A A A B B B 包含可训练参数。

则 LoRA 的前向传递函数为:

初始化时,常将低秩矩阵 A A A高斯初始化, B B B初始化为0。这样在训练初期AB接近于零,不会影响模型的输出。

LoRA微调架构

PISSA

三种微调方式架构

从图中可以看出,PISSA和LoRA主要的区别是初始化方式不同:

  • LoRA:使用随机高斯分布初始化 A A A B B B初始化为零。过程中只训练了低秩矩阵 A A A B B B
  • PISSA:同样基于低秩特性的假设,但PISSA不是去近似 ∆ W ∆W W,而是直接对 W W W进行操作。PiSSA使用奇异值分解(SVD)将 W W W分解为两个矩阵 A A A B B B的乘积加上一个残差矩阵 W r e s W^{res} Wres A A A B B B使用 W W W的主奇异值和奇异向量进行初始化,而 W r e s W^{res} Wres则使用剩余的奇异值和奇异向量初始化,并在微调过程中保持不变。也就能保证初始化时和基座模型一样。因此,和LoRA一样,PISSA的训练中也只训练了低秩矩阵 A A A B B B,而 W r e s W^{res} Wres保持冻结

初始化A和B矩阵:使用主要的奇异值和奇异向量初始化两个可训练的矩阵:


构建残差矩阵 W r e s W^{res} Wres:使用残差奇异值和奇异向量构建残差矩阵:

实验

PISSA微调

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_datasetmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(# init_lora_weights="pissa", # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds.
)
peft_model = get_peft_model(model, lora_config)peft_model.print_trainable_parameters()dataset = load_dataset("imdb", split="train[:1%]")trainer = SFTTrainer(model=peft_model,train_dataset=dataset,dataset_text_field="text",max_seq_length=128,tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("pissa-llama-2-7b")

pissa加载

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Performs SVD again to initialize the residual model and loads the state_dict of the fine-tuned PiSSA modules.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b")

将 PiSSA 转换为 LoRA

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# No SVD is performed during this step, and the base model remains unaltered.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")

总结

PiSSA是一种高效的微调方法,它通过奇异值分解提取大型语言模型中的关键参数,并仅对这些参数进行更新,以实现与全参数微调相似的性能,同时显著降低计算成本和参数数量。

参考文献

  • PISSA: PRINCIPAL SINGULAR VALUES AND SINGULAR VECTORS ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2404.02948
  • https://github.com/GraphPKU/PiSSA
  • LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS,https://arxiv.org/pdf/2106.09685
  • https://github.com/microsoft/LoRA

这篇关于【LLM】PISSA:一种高效的微调方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1084864

相关文章

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

Java判断多个时间段是否重合的方法小结

《Java判断多个时间段是否重合的方法小结》这篇文章主要为大家详细介绍了Java中判断多个时间段是否重合的方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录判断多个时间段是否有间隔判断时间段集合是否与某时间段重合判断多个时间段是否有间隔实体类内容public class D

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

IDEA编译报错“java: 常量字符串过长”的原因及解决方法

《IDEA编译报错“java:常量字符串过长”的原因及解决方法》今天在开发过程中,由于尝试将一个文件的Base64字符串设置为常量,结果导致IDEA编译的时候出现了如下报错java:常量字符串过长,... 目录一、问题描述二、问题原因2.1 理论角度2.2 源码角度三、解决方案解决方案①:StringBui

Linux使用nload监控网络流量的方法

《Linux使用nload监控网络流量的方法》Linux中的nload命令是一个用于实时监控网络流量的工具,它提供了传入和传出流量的可视化表示,帮助用户一目了然地了解网络活动,本文给大家介绍了Linu... 目录简介安装示例用法基础用法指定网络接口限制显示特定流量类型指定刷新率设置流量速率的显示单位监控多个

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

C#中读取XML文件的四种常用方法

《C#中读取XML文件的四种常用方法》Xml是Internet环境中跨平台的,依赖于内容的技术,是当前处理结构化文档信息的有力工具,下面我们就来看看C#中读取XML文件的方法都有哪些吧... 目录XML简介格式C#读取XML文件方法使用XmlDocument使用XmlTextReader/XmlTextWr

C++初始化数组的几种常见方法(简单易懂)

《C++初始化数组的几种常见方法(简单易懂)》本文介绍了C++中数组的初始化方法,包括一维数组和二维数组的初始化,以及用new动态初始化数组,在C++11及以上版本中,还提供了使用std::array... 目录1、初始化一维数组1.1、使用列表初始化(推荐方式)1.2、初始化部分列表1.3、使用std::

oracle DBMS_SQL.PARSE的使用方法和示例

《oracleDBMS_SQL.PARSE的使用方法和示例》DBMS_SQL是Oracle数据库中的一个强大包,用于动态构建和执行SQL语句,DBMS_SQL.PARSE过程解析SQL语句或PL/S... 目录语法示例注意事项DBMS_SQL 是 oracle 数据库中的一个强大包,它允许动态地构建和执行