李沐64_注意力机制——自学笔记

2024-04-24 19:20

本文主要是介绍李沐64_注意力机制——自学笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

注意力机制

1.卷积、全连接和池化层都只考虑不随意线索

2.注意力机制则显示的考虑随意线索
(1)随意线索倍称之为查询(query)

(2)每个输入是一个值value,和不随意线索key的对

(3)通过注意力池化层来有偏向性的选择某些输入

总结

注意力机制中,通过query(随意线索)和key(不随意线索)来有偏向性的选择输入

代码实现:注意力汇聚:Nadaraya-Watson 核回归

!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

在这里生成了50个训练样本和50个测试样本。

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数f
(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

批量矩阵乘法

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
tensor([[[ 4.5000]],[[14.5000]]])

使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

将训练数据集变换为键和值用于训练注意力模型。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train, keys, values), y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述

如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')

在这里插入图片描述

这篇关于李沐64_注意力机制——自学笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/932618

相关文章

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.

Redis主从/哨兵机制原理分析

《Redis主从/哨兵机制原理分析》本文介绍了Redis的主从复制和哨兵机制,主从复制实现了数据的热备份和负载均衡,而哨兵机制可以监控Redis集群,实现自动故障转移,哨兵机制通过监控、下线、选举和故... 目录一、主从复制1.1 什么是主从复制1.2 主从复制的作用1.3 主从复制原理1.3.1 全量复制

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1

Java如何通过反射机制获取数据类对象的属性及方法

《Java如何通过反射机制获取数据类对象的属性及方法》文章介绍了如何使用Java反射机制获取类对象的所有属性及其对应的get、set方法,以及如何通过反射机制实现类对象的实例化,感兴趣的朋友跟随小编一... 目录一、通过反射机制获取类对象的所有属性以及相应的get、set方法1.遍历类对象的所有属性2.获取

MySQL中的锁和MVCC机制解读

《MySQL中的锁和MVCC机制解读》MySQL事务、锁和MVCC机制是确保数据库操作原子性、一致性和隔离性的关键,事务必须遵循ACID原则,锁的类型包括表级锁、行级锁和意向锁,MVCC通过非锁定读和... 目录mysql的锁和MVCC机制事务的概念与ACID特性锁的类型及其工作机制锁的粒度与性能影响多版本

Spring使用@Retryable实现自动重试机制

《Spring使用@Retryable实现自动重试机制》在微服务架构中,服务之间的调用可能会因为一些暂时性的错误而失败,例如网络波动、数据库连接超时或第三方服务不可用等,在本文中,我们将介绍如何在Sp... 目录引言1. 什么是 @Retryable?2. 如何在 Spring 中使用 @Retryable

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识