【HuggingFace Transformers】BertSelfAttention源码解析

2024-08-24 08:44

本文主要是介绍【HuggingFace Transformers】BertSelfAttention源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

BertSelfAttention源码解析

  • 1. BertSelfAttention类 介绍
    • 1.1 关键组件
    • 1.2 主要方法
  • 2. BertSelfAttention类 源码解析(核心简版)
  • 3. BertSelfAttention类 源码解析

1. BertSelfAttention类 介绍

BertSelfAttention 类是 BERT 模型的核心组件之一,主要负责实现多头自注意力机制。通过注意力机制,模型可以捕捉到输入序列中各个位置之间的依赖关系。以下是对 BertSelfAttention 类的详细介绍:

1.1 关键组件

  • num_attention_heads:注意力头的数量。多头注意力机制通过使用多个注意力头来增强模型的表达能力,每个头在不同的子空间中学习注意力模式。

  • attention_head_size:每个注意力头的维度。它等于 hidden_size 除以 num_attention_heads

  • all_head_size:所有注意力头的总维度。它等于 attention_head_size 乘以 num_attention_heads,通常与 hidden_size 相等。

  • query, key, value:线性变换层,用于将输入序列映射到查询(Q)、键(K)和值(V)表示。这些是计算注意力权重的基础。

  • dropout:用于防止过拟合的 Dropout 层,应用在计算出的注意力权重上。

  • position_embedding_type:位置嵌入的类型,BERT 主要使用绝对位置嵌入,但该类也支持相对位置嵌入(如 relative_keyrelative_key_query)。

  • distance_embedding:在使用相对位置嵌入时,模型学习的相对位置距离嵌入。

  • is_decoder:指示是否为解码器模型的一部分。这在解码器-编码器架构(如 Transformer)中非常重要。

1.2 主要方法

  • __init__初始化方法,配置并创建注意力层的各个组件。它会检查输入的 hidden_size 是否能被 num_attention_heads 整除,以确保每个注意力头处理的维度是均匀的。

  • transpose_for_scores:将输入张量的形状从 [batch_size, seq_length, hidden_size] 转换为 [batch_size, num_attention_heads, seq_length, attention_head_size],以便进行多头并行计算。

  • forward前向传播方法,执行自注意力计算,计算过程参考公式。具体步骤包括:
    (1) 输入的 hidden_states 通过 query, key, value 层进行线性变换,生成 Q, K, V
    (2) 计算 QK 的点积来生成注意力分数
    (3) 对分数进行缩放,并应用 softmax 生成注意力权重
    (4) 将注意力权重V 相乘生成上下文向量
    (5) 如果需要,返回注意力权重上下文向量

2. BertSelfAttention类 源码解析(核心简版)

这里我们设定配置为:

position_embedding_type="absolute"
is_decoder = False
encoder_hidden_states = None
past_key_value = None

即核心简化版的BertSelfAttention类为:

# -*- coding: utf-8 -*-
# @time: 2024/8/23 18:46import torch
import mathfrom torch import nn
from typing import Optional, Tupleclass BertSelfAttention(nn.Module):def __init__(self, config, position_embedding_type=None):super().__init__()"""hidden size需要能被attention头的数量整除,以确保每个头能处理hidden size的相等部分。例如,如果hidden_size是768,num_attention_heads是12,那么768 % 12等于0,这意味着配置是有效的。"""# ----------------------------------------------检查配置--------------------------------------------------------# 如果 hidden_size 不能被 num_attention_heads 整除,并且 config 对象没有 embedding_size 属性, 引发 ValueError,说明 hidden_size 和 num_attention_heads 不兼容if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "f"heads ({config.num_attention_heads})")# 1. 获取注意力头数量(num_attention_heads), 每个注意力头的大小(attention_head_size), 所有注意力头的大小(all_head_size)# 设置注意力头的数量为配置中的num_attention_heads,决定了有多少个并行的注意力头,例如:12self.num_attention_heads = config.num_attention_heads# 计算每个注意力头的尺寸,即hidden_size除以注意力头的数量,决定了每个注意力头处理的特征维度大小,例如:64self.attention_head_size = int(config.hidden_size / config.num_attention_heads)# 计算所有注意力头的总尺寸,即注意力头数量乘以每个头的尺寸,是所有注意力头的总特征维度大小,通常等于 hidden_size,例如:768self.all_head_size = self.num_attention_heads * self.attention_head_size# 2. 定义query, key, value 线性变换层, dropout层, position_embedding_type, (max_position_embeddings, distance_embedding), is_decoder# 定义query,key,value线性变换层,将hidden_size映射到all_head_sizeself.query = nn.Linear(config.hidden_size, self.all_head_size)self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)# 3. 定义dropout层,用于注意力概率的dropout,防止过拟合self.dropout = nn.Dropout(config.attention_probs_dropout_prob)# 4. 设置位置嵌入类型,如果没有提供则从配置中获取,默认为'absolute'self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")# 如果位置嵌入类型是 'relative_key'或'relative_key_query', 设置最大位置嵌入数量为配置中的max_position_embeddings 以及 距离嵌入if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":self.max_position_embeddings = config.max_position_embeddingsself.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)# 5. 设置是否为解码器self.is_decoder = config.is_decoder# 转换张量维度方法def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:# 获取new_x_shape,保持除最后一维外的所有维度不变,然后将最后一维拆分为num_attention_heads和attention_head_size的维度new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(new_x_shape)  # 将输入张量x重塑为new_x_shape# 将张量维度从 (batch_size, seq_length, num_attention_heads, attention_head_size) 转置为 (batch_size, num_attention_heads, seq_length, attention_head_size)return x.permute(0, 2, 1, 3)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.FloatTensor] = None,head_mask: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.FloatTensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,output_attentions: Optional[bool] = False,) -> Tuple[torch.Tensor]:# 1. 获取 key, value, query 层mixed_query_layer = self.query(hidden_states)key_layer = self.transpose_for_scores(self.key(hidden_states))value_layer = self.transpose_for_scores(self.value(hidden_states))query_layer = self.transpose_for_scores(mixed_query_layer)# 2. 计算 query 和 key 的点积,得到注意力得分attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))# 3. 归一化 attention 得分:对注意力得分进行缩放,并应用注意力掩码,例如:sqrt(64)attention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:attention_scores = attention_scores + attention_mask# 4. 计算注意力概率:使用 softmax 计算注意力权重,并应用 dropoutattention_probs = nn.functional.softmax(attention_scores, dim=-1)attention_probs = self.dropout(attention_probs)# 5. 应用头部掩码:如果有头部掩码,应用头部掩码if head_mask is not None:attention_probs = attention_probs * head_mask# 6. 计算上下文层:计算 attention_probs 和 value 的点积,得到上下文层,并进行变形。context_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # 确保tensor在内存中是连续的new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(new_context_layer_shape)# 7.返回输出:根据 output_attentions 参数,决定是否返回注意力权重。如果是解码器,还要返回缓存的键值对outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)return outputs

3. BertSelfAttention类 源码解析

# -*- coding: utf-8 -*-
# @author: yyj
# @time: 2024/7/15 14:28import torch
import mathfrom torch import nn
from typing import Optional, Tupleclass BertSelfAttention(nn.Module):def __init__(self, config, position_embedding_type=None):super().__init__()"""hidden size需要能被attention头的数量整除,以确保每个头能处理hidden size的相等部分。例如,如果hidden_size是768,num_attention_heads是12,那么768 % 12等于0,这意味着配置是有效的。"""# ----------------------------------------------检查配置--------------------------------------------------------# 如果 hidden_size 不能被 num_attention_heads 整除,并且 config 对象没有 embedding_size 属性, 引发 ValueError,说明 hidden_size 和 num_attention_heads 不兼容if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "f"heads ({config.num_attention_heads})")# 1. 获取注意力头数量(num_attention_heads), 每个注意力头的大小(attention_head_size), 所有注意力头的大小(all_head_size)# 设置注意力头的数量为配置中的num_attention_heads,决定了有多少个并行的注意力头,例如:12self.num_attention_heads = config.num_attention_heads# 计算每个注意力头的尺寸,即hidden_size除以注意力头的数量,决定了每个注意力头处理的特征维度大小,例如:64self.attention_head_size = int(config.hidden_size / config.num_attention_heads)# 计算所有注意力头的总尺寸,即注意力头数量乘以每个头的尺寸,是所有注意力头的总特征维度大小,通常等于 hidden_size,例如:768self.all_head_size = self.num_attention_heads * self.attention_head_size# 2. 定义query, key, value 线性变换层, dropout层, position_embedding_type, (max_position_embeddings, distance_embedding), is_decoder# 定义query,key,value线性变换层,将hidden_size映射到all_head_sizeself.query = nn.Linear(config.hidden_size, self.all_head_size)self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)# 3. 定义dropout层,用于注意力概率的dropout,防止过拟合self.dropout = nn.Dropout(config.attention_probs_dropout_prob)# 4. 设置位置嵌入类型,如果没有提供则从配置中获取,默认为'absolute'self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")# 如果位置嵌入类型是 'relative_key'或'relative_key_query', 设置最大位置嵌入数量为配置中的max_position_embeddings 以及 距离嵌入if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":self.max_position_embeddings = config.max_position_embeddingsself.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)# 5. 设置是否为解码器self.is_decoder = config.is_decoder# 转换张量维度方法def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:# 获取new_x_shape,保持除最后一维外的所有维度不变,然后将最后一维拆分为num_attention_heads和attention_head_size的维度new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(new_x_shape)  # 将输入张量x重塑为new_x_shape# 将张量维度从 (batch_size, seq_length, num_attention_heads, attention_head_size) 转置为 (batch_size, num_attention_heads, seq_length, attention_head_size)return x.permute(0, 2, 1, 3)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.FloatTensor] = None,head_mask: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.FloatTensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,output_attentions: Optional[bool] = False,) -> Tuple[torch.Tensor]:# -------------1. 计算Query层-----------mixed_query_layer = self.query(hidden_states)# If this is instantiated as a cross-attention module, the keys# and values come from an encoder; the attention mask needs to be# such that the encoder's padding tokens are not attended to.# 如果这是作为交叉注意力模块实例化的,键和值来自编码器;注意力掩码需要确保编码器的填充标记不会被关注到。# --------2. 根据是否为交叉注意力和是否有缓存的键值对,来决定如何获取 key 和 value 层,并设置 attention_mask---------is_cross_attention = encoder_hidden_states is not Noneif is_cross_attention and past_key_value is not None:# reuse k,v, cross_attentionskey_layer = past_key_value[0]value_layer = past_key_value[1]attention_mask = encoder_attention_maskelif is_cross_attention:  # 如果提供了 encoder_hidden_states,使用编码器隐藏状态计算键和值key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))attention_mask = encoder_attention_maskelif past_key_value is not None:  # 如果有 past_key_value,则将旧的键和值与当前的键和值拼接key_layer = self.transpose_for_scores(self.key(hidden_states))value_layer = self.transpose_for_scores(self.value(hidden_states))key_layer = torch.cat([past_key_value[0], key_layer], dim=2)value_layer = torch.cat([past_key_value[1], value_layer], dim=2)else:  # 直接使用当前的隐藏状态计算键和值key_layer = self.transpose_for_scores(self.key(hidden_states))value_layer = self.transpose_for_scores(self.value(hidden_states))# -----------------1. 转置 Query 层: 将 query 层转置以适应多头注意力的格式-----------------query_layer = self.transpose_for_scores(mixed_query_layer)# ----------------2. 如果是解码器并且有缓存键值对,则将当前的 key 和 value 层进行缓存-------------use_cache = past_key_value is not Noneif self.is_decoder:# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.# Further calls to cross_attention layer can then reuse all cross-attention# key/value_states (first "if" case)# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of# all previous decoder key/value_states. Further calls to uni-directional self-attention# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)# if encoder bi-directional self-attention `past_key_value` is always `None`past_key_value = (key_layer, value_layer)# Take the dot product between "query" and "key" to get the raw attention scores.# -----------------5. 计算 query 和 key 的点积,得到注意力得分--------------attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))# 3. 相对位置的嵌入:如果使用相对位置嵌入,根据相对位置计算注意力得分并加到 attention_scores 上# 相对位置编码允许模型捕捉输入序列中标记之间的相对位置信息,而不是绝对位置信息。# 具体来说,这段代码通过计算查询和键之间的相对距离,然后使用这些距离来调整注意力分数。if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":query_length, key_length = query_layer.shape[2], key_layer.shape[2]# position_ids_l 是 query 层的position_id# position_ids_r 是 key 层的position_idif use_cache:position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)else:position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)distance = position_ids_l - position_ids_r  # 计算query位置id和key位置id之间的相对距离"""distance: 以 query_length = 6, key_length = 6为例:position_ids_l = [[0], [1], [2], [3], [4], [5]]position_ids_r = [[0, 1, 2, 3, 4, 5]]distance = position_ids_l - position_ids_r# 计算后的 distance 张量:distance = [[ 0, -1, -2, -3, -4, -5], [ 1,  0, -1, -2, -3, -4], [ 2,  1,  0, -1, -2, -3], [ 3,  2,  1,  0, -1, -2], [ 4,  3,  2,  1,  0, -1], [ 5,  4,  3,  2,  1,  0]]"""positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility# positional_embedding的shape: torch.Size([seq_length, seq_length, hidden_dim / num_head])# 如果 position_embedding_type 是 relative_key,计算查询层与相对位置嵌入的内积,得到相对位置得分,然后加到注意力得分上。# einsum 是爱因斯坦求和约定(Einstein summation convention)# 详解参考:https://blog.csdn.net/weixin_47936614/article/details/141468836if self.position_embedding_type == "relative_key":relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)attention_scores = attention_scores + relative_position_scoreselif self.position_embedding_type == "relative_key_query":relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key# 4. 归一化 attention 得分:对注意力得分进行缩放,并应用注意力掩码,例如:sqrt(64)attention_scores = attention_scores / math.sqrt(self.attention_head_size)if attention_mask is not None:# Apply the attention mask is (precomputed for all layers in BertModel forward() function)# 应用注意力掩码(在BertModel的forward()函数中预先计算用于所有层)attention_scores = attention_scores + attention_mask# 5. 计算注意力概率:使用 softmax 计算注意力权重,并应用 dropout# Normalize the attention scores to probabilities.attention_probs = nn.functional.softmax(attention_scores, dim=-1)# This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs = self.dropout(attention_probs)# Mask heads if we want to# 6. 应用头部掩码:如果有头部掩码,应用头部掩码if head_mask is not None:attention_probs = attention_probs * head_mask# 7. 计算上下文层:计算 attention_probs 和 value 的点积,得到上下文层,并进行变形。context_layer = torch.matmul(attention_probs, value_layer)# 对context_layer进行维度转换,使其符合预期的顺序# 这里的permute操作将tensor的维度从 (batch_size, num_heads, seq_length, head_dim) 转换为 (batch_size, seq_length, num_heads, head_dim)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # 确保tensor在内存中是连续的# 创建新的context_layer形状,将最后两个维度合并成一个# new_context_layer_shape 的形状为 (batch_size, seq_length, all_head_size),其中all_head_size = num_heads * head_dimnew_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)# 重新调整context_layer的view,使其符合新的形状context_layer = context_layer.view(new_context_layer_shape)# 8.返回输出:根据 output_attentions 参数,决定是否返回注意力权重。如果是解码器,还要返回缓存的键值对outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)if self.is_decoder:outputs = outputs + (past_key_value,)return outputs

这篇关于【HuggingFace Transformers】BertSelfAttention源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1101998

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get

工厂ERP管理系统实现源码(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程,确保采购过程的透明和高效。仓库管理方面,实现库存的精准管理,包括入库、出库、盘点等操作,确保库存数据的准确性和实时性。 生产管理模块则涵盖了生产计划制定、物料需求计划、

OWASP十大安全漏洞解析

OWASP(开放式Web应用程序安全项目)发布的“十大安全漏洞”列表是Web应用程序安全领域的权威指南,它总结了Web应用程序中最常见、最危险的安全隐患。以下是对OWASP十大安全漏洞的详细解析: 1. 注入漏洞(Injection) 描述:攻击者通过在应用程序的输入数据中插入恶意代码,从而控制应用程序的行为。常见的注入类型包括SQL注入、OS命令注入、LDAP注入等。 影响:可能导致数据泄

从状态管理到性能优化:全面解析 Android Compose

文章目录 引言一、Android Compose基本概念1.1 什么是Android Compose?1.2 Compose的优势1.3 如何在项目中使用Compose 二、Compose中的状态管理2.1 状态管理的重要性2.2 Compose中的状态和数据流2.3 使用State和MutableState处理状态2.4 通过ViewModel进行状态管理 三、Compose中的列表和滚动

Spring 源码解读:自定义实现Bean定义的注册与解析

引言 在Spring框架中,Bean的注册与解析是整个依赖注入流程的核心步骤。通过Bean定义,Spring容器知道如何创建、配置和管理每个Bean实例。本篇文章将通过实现一个简化版的Bean定义注册与解析机制,帮助你理解Spring框架背后的设计逻辑。我们还将对比Spring中的BeanDefinition和BeanDefinitionRegistry,以全面掌握Bean注册和解析的核心原理。

CSP 2023 提高级第一轮 CSP-S 2023初试题 完善程序第二题解析 未完

一、题目阅读 (最大值之和)给定整数序列 a0,⋯,an−1,求该序列所有非空连续子序列的最大值之和。上述参数满足 1≤n≤105 和 1≤ai≤108。 一个序列的非空连续子序列可以用两个下标 ll 和 rr(其中0≤l≤r<n0≤l≤r<n)表示,对应的序列为 al,al+1,⋯,ar​。两个非空连续子序列不同,当且仅当下标不同。 例如,当原序列为 [1,2,1,2] 时,要计算子序列 [