SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)

本文主要是介绍SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

地址

https://arxiv.org/pdf/2404.14469

核心

这篇论文介绍了一种名为SnapKV的创新方法,旨在提高大型语言模型处理长上下文时的效率和内存利用率。主要贡献包括: 1. 设计实验探索在输出生成过程中注意力特征的模式,发现注意力分配具有一致性,可以提取重要信息。 2. 提出了SnapKV算法,利用观察窗口和投票机制选择每个注意力头的重要键值对,并使用池化进行细粒度聚类。 3. 在多个模型和数据集上评估SnapKV,结果显示其可以大幅压缩键值对缓存,提高解码速度,同时保持模型性能。 总之,SnapKV为长序列输入提供了一种高效压缩键值对缓存的方法,有助于降低内存和计算成本,同时保持了生成质量。

import torch
import time
import torch.nn.functional as F
import torch.nn as nn
import math# perform qk calculation and get indices
# this version will not update in inference mode# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:"""This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)"""batch, num_key_value_heads, slen, head_dim = hidden_states.shapeif n_rep == 1:return hidden_stateshidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)class KVCluster():def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):self.window_size = window_sizeself.max_capacity_prompt = max_capacity_promptassert self.max_capacity_prompt - self.window_size > 0self.kernel_size = kernel_sizeself.pooling = poolingdef update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):# check if prefix phaseassert key_states.shape[-2] == query_states.shape[-2]bsz, num_heads, q_len, head_dim = query_states.shapeif q_len < self.max_capacity_prompt:return key_states, value_stateselse:attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(attn_weights.device)attention_mask = mask[None, None, :, :]attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_maskattn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2)if self.pooling == 'avgpool':attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)elif self.pooling == 'maxpool':attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)else:raise ValueError('Pooling method not supported')indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indicesindices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)k_cur = key_states[:, :, -self.window_size:, :]v_cur = value_states[:, :, -self.window_size:, :]key_states = torch.cat([k_past_compress, k_cur], dim = 2)value_states = torch.cat([v_past_compress, v_cur], dim = 2)return key_states, value_states

这段代码实现了一个KVCluster类,用于更新键值对(key-value pairs)。该类具有以下方法:

  • __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):初始化KVCluster对象,可以设置窗口大小、最大容量、卷积核大小和池化方法。

  • reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):重置KVCluster对象的参数。

  • update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):根据输入的键、查询和值的状态以及注意力掩码更新键值对。在查询状态长度小于最大容量时,直接返回原始键值对。否则,根据注意力权重计算窗口内的聚合权重,并根据聚合权重选择top-k的索引。然后将过去的键值对和当前的键值对进行拼接,返回更新后的键值对。

该类主要用于处理键值对的更新,其中关键的部分是计算注意力权重、选择top-k索引和拼接过去和当前的键值对。

根据文档内容,使用SnapKV方法可以按照以下步骤进行:

  1. 确定观测窗口大小:选择输入序列末尾的一部分作为观测窗口,以捕获重要的注意力特征。通常选择窗口大小为32。
  2. 计算注意力权重:对观测窗口的查询和输入序列的前缀进行注意力计算,得到注意力权重矩阵。
  3. 进行投票:对每个注意力头,将观测窗口的注意力权重相加,选出权重最大的前缀位置作为重要特征。
  4. 聚类:对选出的重要特征进行聚类,以保留相邻特征。可以通过1D最大池化实现聚类。
  5. 更新KV缓存:将聚类后的特征与前缀特征拼接,形成新的Key-Value对,并更新KV缓存。这可以将KV缓存的大小压缩到指定值。
  6. 生成:使用更新后的KV缓存进行解码生成。由于KV缓存大小不再随输入序列增长,因此可以显著提高解码速度和内存效率。
  7. 调整参数:根据需要调整观测窗口大小、聚类核大小、KV缓存压缩目标值等参数,以平衡性能和效率。 总的来说,SnapKV通过自动识别输入序列中重要的注意力特征,并仅保留这些特征来压缩KV缓存,实现高效的长序列生成。该方法无需训练,可直接应用于现有模型中。

是的,根据文档中对SnapKV方法的描述,该步骤是在softmax之前进行的。 文档中提到SnapKV包含两个阶段:

  1. 投票阶段:计算观测窗口内查询和前缀的注意力权重,并进行投票,以选择出重要的前缀特征。
  2. 更新和存储阶段:根据投票结果,选择重要特征进行聚类,并拼接这些特征与前缀特征,形成新的KV对,以更新KV缓存。 这一过程发生在softmax之前,也就是在计算注意力权重时进行的。文档中并未明确指出是在softmax之前,但从上下文来看,这一过程发生在注意力权重计算阶段,因此是在softmax之前进行的。
    总之,SnapKV是在计算注意力权重时,通过压缩模型中提示的KV缓存来提高生成效率的,因此是在softmax之前进行的。

这篇关于SnapKV: LLM Knows What You are Looking for Before Generation(实现超长上下文的压缩方法无需训练)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/942919

相关文章

JavaScript中的reduce方法执行过程、使用场景及进阶用法

《JavaScript中的reduce方法执行过程、使用场景及进阶用法》:本文主要介绍JavaScript中的reduce方法执行过程、使用场景及进阶用法的相关资料,reduce是JavaScri... 目录1. 什么是reduce2. reduce语法2.1 语法2.2 参数说明3. reduce执行过程

C#中读取XML文件的四种常用方法

《C#中读取XML文件的四种常用方法》Xml是Internet环境中跨平台的,依赖于内容的技术,是当前处理结构化文档信息的有力工具,下面我们就来看看C#中读取XML文件的方法都有哪些吧... 目录XML简介格式C#读取XML文件方法使用XmlDocument使用XmlTextReader/XmlTextWr

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

C++初始化数组的几种常见方法(简单易懂)

《C++初始化数组的几种常见方法(简单易懂)》本文介绍了C++中数组的初始化方法,包括一维数组和二维数组的初始化,以及用new动态初始化数组,在C++11及以上版本中,还提供了使用std::array... 目录1、初始化一维数组1.1、使用列表初始化(推荐方式)1.2、初始化部分列表1.3、使用std::

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import

oracle DBMS_SQL.PARSE的使用方法和示例

《oracleDBMS_SQL.PARSE的使用方法和示例》DBMS_SQL是Oracle数据库中的一个强大包,用于动态构建和执行SQL语句,DBMS_SQL.PARSE过程解析SQL语句或PL/S... 目录语法示例注意事项DBMS_SQL 是 oracle 数据库中的一个强大包,它允许动态地构建和执行