SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)

本文主要是介绍SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

地址

https://arxiv.org/pdf/2404.14469

核心

这篇论文介绍了一种名为SnapKV的创新方法,旨在提高大型语言模型处理长上下文时的效率和内存利用率。主要贡献包括: 1. 设计实验探索在输出生成过程中注意力特征的模式,发现注意力分配具有一致性,可以提取重要信息。 2. 提出了SnapKV算法,利用观察窗口和投票机制选择每个注意力头的重要键值对,并使用池化进行细粒度聚类。 3. 在多个模型和数据集上评估SnapKV,结果显示其可以大幅压缩键值对缓存,提高解码速度,同时保持模型性能。 总之,SnapKV为长序列输入提供了一种高效压缩键值对缓存的方法,有助于降低内存和计算成本,同时保持了生成质量。

import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math# perform qk calculation and get indices
# this version will not update in inference mode# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)"""batch, num_key_value_heads, slen, head_dim = hidden_states.shapeif n_rep == 1:return hidden_stateshidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)class KVCluster():def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):# check if prefix phaseassert key_states.shape[-2] == query_states.shape[-2]bsz, num_heads, q_len, head_dim = query_states.shapeif q_len < self.max_capacity_prompt:return key_states, value_stateselse:attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(attn_weights.device)attention_mask = mask[None, None, :, :]attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2)if self.pooling == 'avgpool':attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)elif self.pooling == 'maxpool':attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)else:raise ValueError('Pooling method not supported')indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indicesindices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)k_cur = key_states[:, :, -self.window_size:, :]v_cur = value_states[:, :, -self.window_size:, :]key_states = torch.cat([k_past_compress, k_cur], dim = 2)value_states = torch.cat([v_past_compress, v_cur], dim = 2)return key_states, value_states

这段代码实现了一个KVCluster类,用于更新键值对(key-value pairs)。该类具有以下方法:

  • __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):初始化KVCluster对象,可以设置窗口大小、最大容量、卷积核大小和池化方法。

  • reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):重置KVCluster对象的参数。

  • update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):根据输入的键、查询和值的状态以及注意力掩码更新键值对。在查询状态长度小于最大容量时,直接返回原始键值对。否则,根据注意力权重计算窗口内的聚合权重,并根据聚合权重选择top-k的索引。然后将过去的键值对和当前的键值对进行拼接,返回更新后的键值对。

该类主要用于处理键值对的更新,其中关键的部分是计算注意力权重、选择top-k索引和拼接过去和当前的键值对。

根据文档内容,使用SnapKV方法可以按照以下步骤进行:

  1. 确定观测窗口大小:选择输入序列末尾的一部分作为观测窗口,以捕获重要的注意力特征。通常选择窗口大小为32。
  2. 计算注意力权重:对观测窗口的查询和输入序列的前缀进行注意力计算,得到注意力权重矩阵。
  3. 进行投票:对每个注意力头,将观测窗口的注意力权重相加,选出权重最大的前缀位置作为重要特征。
  4. 聚类:对选出的重要特征进行聚类,以保留相邻特征。可以通过1D最大池化实现聚类。
  5. 更新KV缓存:将聚类后的特征与前缀特征拼接,形成新的Key-Value对,并更新KV缓存。这可以将KV缓存的大小压缩到指定值。
  6. 生成:使用更新后的KV缓存进行解码生成。由于KV缓存大小不再随输入序列增长,因此可以显著提高解码速度和内存效率。
  7. 调整参数:根据需要调整观测窗口大小、聚类核大小、KV缓存压缩目标值等参数,以平衡性能和效率。 总的来说,SnapKV通过自动识别输入序列中重要的注意力特征,并仅保留这些特征来压缩KV缓存,实现高效的长序列生成。该方法无需训练,可直接应用于现有模型中。

是的,根据文档中对SnapKV方法的描述,该步骤是在softmax之前进行的。 文档中提到SnapKV包含两个阶段:

  1. 投票阶段:计算观测窗口内查询和前缀的注意力权重,并进行投票,以选择出重要的前缀特征。
  2. 更新和存储阶段:根据投票结果,选择重要特征进行聚类,并拼接这些特征与前缀特征,形成新的KV对,以更新KV缓存。 这一过程发生在softmax之前,也就是在计算注意力权重时进行的。文档中并未明确指出是在softmax之前,但从上下文来看,这一过程发生在注意力权重计算阶段,因此是在softmax之前进行的。
    总之,SnapKV是在计算注意力权重时,通过压缩模型中提示的KV缓存来提高生成效率的,因此是在softmax之前进行的。

这篇关于SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/942919

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

hdu1565(状态压缩)

本人第一道ac的状态压缩dp,这题的数据非常水,很容易过 题意:在n*n的矩阵中选数字使得不存在任意两个数字相邻,求最大值 解题思路: 一、因为在1<<20中有很多状态是无效的,所以第一步是选择有效状态,存到cnt[]数组中 二、dp[i][j]表示到第i行的状态cnt[j]所能得到的最大值,状态转移方程dp[i][j] = max(dp[i][j],dp[i-1][k]) ,其中k满足c

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验