大模型面试准备(五):图解 Transformer 最关键模块 MHA

2024-03-27 03:04

本文主要是介绍大模型面试准备(五):图解 Transformer 最关键模块 MHA,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何备战、面试常考点分享等热门话题进行了深入的讨论。


合集在这里:《大模型面试宝典》(2024版) 正式发布!


Transformer 原始论文中的模型结构如下图所示:
图片

上一篇文章讲解了 Transformer 的关键模块 Positional Encoding(大家可以自行翻阅),本篇文章讲解一下 Transformer 的最重要模块 Multi-Head Attention(MHA),毕竟 Transformer 的论文名称就叫 《Attention Is All You Need》。

Transformer 中的 Multi-Head Attention 可以细分为3种,Multi-Head Self-Attention(对应上图左侧Multi-Head Attention模块),Multi-Head Cross-Attention(对应上图右上Multi-Head Attention模块),Masked Multi-Head Self-Attention(对应上图右下Masked Multi-Head Attention模块)。

其中 Self 和 Cross 的区分是对应的 Q和 K、 V是否来自相同的输入。是否Mask的区分是是否需要看见全部输入和预测的输出,Encoder需要看见全部的输入问题,所以不能Mask;而Decoder是预测输出,当前预测只能看见之前的全部预测,不能看见之后的预测,所以需要Mask。

本篇文章主要通过图解的方式对 Multi-Head Attention 的核心思想和计算过程做讲解,喜欢本文记得收藏、点赞、关注。技术和面试交流,文末加入我们

MHA核心思想

在这里插入图片描述

MHA过程图解

注意力计算公式如下:

在这里插入图片描述

图示过程图下:

图片

多头注意力

MHA通过多个头的方式,可以增强自注意力机制聚合上下文信息的能力,以关注上下文的不同侧面,作用类似于CNN的多个卷积核。下面我们就通过一张图来完成MHA的解析:

图片

在这里插入图片描述

单头注意力

知道了多头注意力的实现方式后,那如果是通过单头注意力完成同样的计算,矩阵形式是什么样的呢?下面我还是以一图胜千言的方式来回答这个问题:

图片通过单头注意力的比较,相信大家对多头注意力(MHA)应该有了更好的理解。我们可以发现多头注意力就是将一个单头进行了切分计算,最后又将结果进行了合并,整个过程中的整体维度和计算量基本是不变的,但提升了模型的学习能力。

最后附上一份MHA的实现和Transformer的构建代码:

import torch
import torch.nn as nn# 定义多头自注意力层
class MultiHeadAttention(nn.Module):def __init__(self, d_model, n_heads):super(MultiHeadAttention, self).__init__()self.n_heads = n_heads  # 多头注意力的头数self.d_model = d_model  # 输入维度(模型的总维度)self.head_dim = d_model // n_heads  # 每个注意力头的维度assert self.head_dim * n_heads == d_model, "d_model必须能够被n_heads整除"  # 断言,确保d_model可以被n_heads整除# 线性变换矩阵,用于将输入向量映射到查询、键和值空间self.wq = nn.Linear(d_model, d_model)  # 查询(Query)的线性变换self.wk = nn.Linear(d_model, d_model)  # 键(Key)的线性变换self.wv = nn.Linear(d_model, d_model)  # 值(Value)的线性变换# 最终输出的线性变换,将多头注意力结果合并回原始维度self.fc_out = nn.Linear(d_model, d_model)  # 输出的线性变换def forward(self, query, key, value, mask):# 将嵌入向量分成不同的头query = query.view(query.shape[0], -1, self.n_heads, self.head_dim)key = key.view(key.shape[0], -1, self.n_heads, self.head_dim)value = value.view(value.shape[0], -1, self.n_heads, self.head_dim)# 转置以获得维度 batch_size, self.n_heads, seq_len, self.head_dimquery = query.transpose(1, 2)key = key.transpose(1, 2)value = value.transpose(1, 2)# 计算注意力得分scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dimif mask is not None:scores = scores.masked_fill(mask == 0, -1e9)attention = torch.nn.functional.softmax(scores, dim=-1)out = torch.matmul(attention, value)# 重塑以恢复原始输入形状out = out.transpose(1, 2).contiguous().view(query.shape[0], -1, self.d_model)out = self.fc_out(out)return out# 定义Transformer编码器层
class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, n_heads, dim_feedforward, dropout):super(TransformerEncoderLayer, self).__init__()# 多头自注意力层,接收d_model维度输入,使用n_heads个注意力头self.self_attn = MultiHeadAttention(d_model, n_heads)# 第一个全连接层,将d_model维度映射到dim_feedforward维度self.linear1 = nn.Linear(d_model, dim_feedforward)# 第二个全连接层,将dim_feedforward维度映射回d_model维度self.linear2 = nn.Linear(dim_feedforward, d_model)# 用于随机丢弃部分神经元,以减少过拟合self.dropout = nn.Dropout(dropout)# 第一个层归一化层,用于归一化第一个全连接层的输出self.norm1 = nn.LayerNorm(d_model)# 第二个层归一化层,用于归一化第二个全连接层的输出self.norm2 = nn.LayerNorm(d_model)def forward(self, src, src_mask):# 使用多头自注意力层处理输入src,同时提供src_mask以屏蔽不需要考虑的位置src2 = self.self_attn(src, src, src, src_mask)# 残差连接和丢弃:将自注意力层的输出与原始输入相加,并应用丢弃src = src + self.dropout(src2)# 应用第一个层归一化src = self.norm1(src)# 经过第一个全连接层,再经过激活函数ReLU,然后进行丢弃src2 = self.linear2(self.dropout(torch.nn.functional.relu(self.linear1(src))))# 残差连接和丢弃:将全连接层的输出与之前的输出相加,并再次应用丢弃src = src + self.dropout(src2)# 应用第二个层归一化src = self.norm2(src)# 返回编码器层的输出return src# 实例化模型
vocab_size = 10000  # 词汇表大小(根据实际情况调整)
d_model = 512  # 模型的维度
n_heads = 8  # 多头自注意力的头数
num_encoder_layers = 6  # 编码器层的数量
dim_feedforward = 2048  # 全连接层的隐藏层维度
max_seq_length = 100  # 最大序列长度
dropout = 0.1  # 丢弃率# 创建Transformer模型实例
model = Transformer(vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward, max_seq_length, dropout)

最后的最后再贴上一张非常不错的 Transformer 手绘吧!

在这里插入图片描述

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了算法岗技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

用通俗易懂方式讲解系列

  • 《大模型面试宝典》(2024版) 正式发布!
  • 《大模型实战宝典》(2024版)正式发布!
  • 大模型面试准备(一):LLM主流结构和训练目标、构建流程
  • 大模型面试准备(二):LLM容易被忽略的Tokenizer与Embedding
  • 大模型面试准备(三):聊一聊大模型的幻觉问题
  • 大模型面试准备(四):大模型面试必会的位置编码(绝对位置编码sinusoidal,旋转位置编码RoPE,以及相对位置编码ALiBi)

参考文献:

参考资料:
[1] https://jalammar.github.io/illustrated-transformer/
[2] https://zhuanlan.zhihu.com/p/264468193
[3] https://zhuanlan.zhihu.com/p/662777298

这篇关于大模型面试准备(五):图解 Transformer 最关键模块 MHA的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/850729

相关文章

Python中构建终端应用界面利器Blessed模块的使用

《Python中构建终端应用界面利器Blessed模块的使用》Blessed库作为一个轻量级且功能强大的解决方案,开始在开发者中赢得口碑,今天,我们就一起来探索一下它是如何让终端UI开发变得轻松而高... 目录一、安装与配置:简单、快速、无障碍二、基本功能:从彩色文本到动态交互1. 显示基本内容2. 创建链

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

python中的与时间相关的模块应用场景分析

《python中的与时间相关的模块应用场景分析》本文介绍了Python中与时间相关的几个重要模块:`time`、`datetime`、`calendar`、`timeit`、`pytz`和`dateu... 目录1. time 模块2. datetime 模块3. calendar 模块4. timeit

Python模块导入的几种方法实现

《Python模块导入的几种方法实现》本文主要介绍了Python模块导入的几种方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录一、什么是模块?二、模块导入的基本方法1. 使用import整个模块2.使用from ... i

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G