本文主要是介绍【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
BertIntermediate 和 BertPooler源码解析
- 1. 介绍
- 1.1 位置与功能
- 1.2 相似点与不同点
- 2. 源码解析
- 2.1 BertIntermediate 源码解析
- 2.2 BertPooler 源码解析
1. 介绍
1.1 位置与功能
(1) BertIntermediate
- 位置:位于 BertLayer 的注意力层(BertSelfAttention)和输出层(BertOutput)之间。
- 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。
(2) BertPooler
- 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
- 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。
1.2 相似点与不同点
(1) 相似点
- 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
- 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。
(2) 不同点
- 应用层次:
BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
BertPooler 只在模型的最后一层作用,用于提取全局特征。 - 功能目标:
BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。
2. 源码解析
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
2.1 BertIntermediate 源码解析
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torchfrom torch import nn
from transformers.activations import ACT2FNclass BertIntermediate(nn.Module):def __init__(self, config):super().__init__()# 全连接层,将 hidden_size 映射到 intermediate_sizeself.dense = nn.Linear(config.hidden_size, config.intermediate_size)# 根据 config.hidden_act 定义激活函数if isinstance(config.hidden_act, str):self.intermediate_act_fn = ACT2FN[config.hidden_act]else:self.intermediate_act_fn = config.hidden_actdef forward(self, hidden_states: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states) # 线性变换hidden_states = self.intermediate_act_fn(hidden_states) # 激活函数return hidden_states
2.2 BertPooler 源码解析
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41import torchfrom torch import nnclass BertPooler(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 全连接层,将 hidden_size 映射回 hidden_sizeself.activation = nn.Tanh() # 激活函数为 Tanh 函数def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:# We "pool" the model by simply taking the hidden state corresponding# to the first token.# 提取序列中的第一个 token,也就是 [CLS] 的 hidden statefirst_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor) # 线性变换pooled_output = self.activation(pooled_output) # 激活函数return pooled_output
这篇关于【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!