SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)

本文主要是介绍SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

地址

https://arxiv.org/pdf/2404.14469

核心

这篇论文介绍了一种名为SnapKV的创新方法,旨在提高大型语言模型处理长上下文时的效率和内存利用率。主要贡献包括: 1. 设计实验探索在输出生成过程中注意力特征的模式,发现注意力分配具有一致性,可以提取重要信息。 2. 提出了SnapKV算法,利用观察窗口和投票机制选择每个注意力头的重要键值对,并使用池化进行细粒度聚类。 3. 在多个模型和数据集上评估SnapKV,结果显示其可以大幅压缩键值对缓存,提高解码速度,同时保持模型性能。 总之,SnapKV为长序列输入提供了一种高效压缩键值对缓存的方法,有助于降低内存和计算成本,同时保持了生成质量。

import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math# perform qk calculation and get indices
# this version will not update in inference mode# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)"""batch, num_key_value_heads, slen, head_dim = hidden_states.shapeif n_rep == 1:return hidden_stateshidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)class KVCluster():def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):# check if prefix phaseassert key_states.shape[-2] == query_states.shape[-2]bsz, num_heads, q_len, head_dim = query_states.shapeif q_len < self.max_capacity_prompt:return key_states, value_stateselse:attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(attn_weights.device)attention_mask = mask[None, None, :, :]attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2)if self.pooling == 'avgpool':attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)elif self.pooling == 'maxpool':attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)else:raise ValueError('Pooling method not supported')indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indicesindices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)k_cur = key_states[:, :, -self.window_size:, :]v_cur = value_states[:, :, -self.window_size:, :]key_states = torch.cat([k_past_compress, k_cur], dim = 2)value_states = torch.cat([v_past_compress, v_cur], dim = 2)return key_states, value_states

这段代码实现了一个KVCluster类,用于更新键值对(key-value pairs)。该类具有以下方法:

  • __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):初始化KVCluster对象,可以设置窗口大小、最大容量、卷积核大小和池化方法。

  • reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):重置KVCluster对象的参数。

  • update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):根据输入的键、查询和值的状态以及注意力掩码更新键值对。在查询状态长度小于最大容量时,直接返回原始键值对。否则,根据注意力权重计算窗口内的聚合权重,并根据聚合权重选择top-k的索引。然后将过去的键值对和当前的键值对进行拼接,返回更新后的键值对。

该类主要用于处理键值对的更新,其中关键的部分是计算注意力权重、选择top-k索引和拼接过去和当前的键值对。

根据文档内容,使用SnapKV方法可以按照以下步骤进行:

  1. 确定观测窗口大小:选择输入序列末尾的一部分作为观测窗口,以捕获重要的注意力特征。通常选择窗口大小为32。
  2. 计算注意力权重:对观测窗口的查询和输入序列的前缀进行注意力计算,得到注意力权重矩阵。
  3. 进行投票:对每个注意力头,将观测窗口的注意力权重相加,选出权重最大的前缀位置作为重要特征。
  4. 聚类:对选出的重要特征进行聚类,以保留相邻特征。可以通过1D最大池化实现聚类。
  5. 更新KV缓存:将聚类后的特征与前缀特征拼接,形成新的Key-Value对,并更新KV缓存。这可以将KV缓存的大小压缩到指定值。
  6. 生成:使用更新后的KV缓存进行解码生成。由于KV缓存大小不再随输入序列增长,因此可以显著提高解码速度和内存效率。
  7. 调整参数:根据需要调整观测窗口大小、聚类核大小、KV缓存压缩目标值等参数,以平衡性能和效率。 总的来说,SnapKV通过自动识别输入序列中重要的注意力特征,并仅保留这些特征来压缩KV缓存,实现高效的长序列生成。该方法无需训练,可直接应用于现有模型中。

是的,根据文档中对SnapKV方法的描述,该步骤是在softmax之前进行的。 文档中提到SnapKV包含两个阶段:

  1. 投票阶段:计算观测窗口内查询和前缀的注意力权重,并进行投票,以选择出重要的前缀特征。
  2. 更新和存储阶段:根据投票结果,选择重要特征进行聚类,并拼接这些特征与前缀特征,形成新的KV对,以更新KV缓存。 这一过程发生在softmax之前,也就是在计算注意力权重时进行的。文档中并未明确指出是在softmax之前,但从上下文来看,这一过程发生在注意力权重计算阶段,因此是在softmax之前进行的。
    总之,SnapKV是在计算注意力权重时,通过压缩模型中提示的KV缓存来提高生成效率的,因此是在softmax之前进行的。

这篇关于SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/942919

相关文章

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

Java中读取YAML文件配置信息常见问题及解决方法

《Java中读取YAML文件配置信息常见问题及解决方法》:本文主要介绍Java中读取YAML文件配置信息常见问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 目录1 使用Spring Boot的@ConfigurationProperties2. 使用@Valu

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Java 方法重载Overload常见误区及注意事项

《Java方法重载Overload常见误区及注意事项》Java方法重载允许同一类中同名方法通过参数类型、数量、顺序差异实现功能扩展,提升代码灵活性,核心条件为参数列表不同,不涉及返回类型、访问修饰符... 目录Java 方法重载(Overload)详解一、方法重载的核心条件二、构成方法重载的具体情况三、不构

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

Qt使用QSqlDatabase连接MySQL实现增删改查功能

《Qt使用QSqlDatabase连接MySQL实现增删改查功能》这篇文章主要为大家详细介绍了Qt如何使用QSqlDatabase连接MySQL实现增删改查功能,文中的示例代码讲解详细,感兴趣的小伙伴... 目录一、创建数据表二、连接mysql数据库三、封装成一个完整的轻量级 ORM 风格类3.1 表结构