SelfAttention|自注意力机制ms简单实现

2024-02-15 20:20

本文主要是介绍SelfAttention|自注意力机制ms简单实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

自注意力机制学习有感

  • 观看b站博主的讲解视频以及跟着他的pytorch代码实现mindspore的自注意力机制:
  • up主讲的很好,推荐入门自注意力机制。
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
context.set_context(device_target='Ascend',max_device_memory='1GB') class SelfAttention(nn.Cell):def __init__(self, dim):super(SelfAttention, self).__init__()wq_data = [[1.0, 0], [1., 1.]] # wq权重初始化 超参数wk_data = [[0., 1.], [1., 1.]] # wk权重初始化 超参数wv_data = [[0., 1., 1.], [1., 0., 0.]] # wv权重初始化 超参数self.q = nn.Dense(in_channels=dim, out_channels=2, has_bias=False)self.q.weight.set_data(ms.Tensor(wq_data).T)print("wq value:", self.q.weight.value())self.k = nn.Dense(in_channels = dim, out_channels=2, has_bias=False)self.k.weight.set_data(ms.Tensor(wk_data).T)print('wk value:', self.k.weight.value())self.v = nn.Dense(in_channels=dim, out_channels=3, has_bias=False)# print(self.v.weight.shape)self.v.weight.set_data(ms.Tensor(wv_data).T)print('wv value:',self.v.weight.value())print("*********************" * 2)def construct(self, x):q = self.q(x)print('q value:', q)k = self.k(x)print('k value:', k)v = self.v(x)# xx = x.matmul(ms.Tensor([[0., 1., 1.], [1., 0., 0.]]))print('v value:', v, '\n')print('#################################')x = (q @ k.T)/ms.ops.sqrt(ms.tensor(2.))x = ms.ops.softmax(x) @ vprint("result:", x)x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = ms.Tensor(x)
attn = SelfAttention(2)
attn(x)

结果如下:

wq value: [[1. 1.][0. 1.]]
wk value: [[0. 1.][1. 1.]]
wv value: [[0. 1.][1. 0.][1. 0.]]
******************************************
q value: [[2. 1.][1. 0.][3. 1.][2. 2.]]
k value: [[1. 2.][0. 1.][1. 3.][2. 2.]]
v value: [[1. 1. 1.][0. 1. 1.][1. 2. 2.][2. 0. 0.]] #################################
result: [[1.5499581  0.71284014 0.71284014][1.3395231  0.7726004  0.7726004 ][1.7247156  0.4475609  0.4475609 ][1.4366053  1.         1.        ]]

** 吐槽mindspore说明文档,对ms.nn.Dense的说明太过简单了,有对新手真不友好(对我) **

  • pytorch的文档:
    在这里插入图片描述
  • mindspore的文档:
    在这里插入图片描述
    pytorch有公式,至少提示A的转置有提示。mindspore没有,导致我这步实现的时候输出的结果不对,还是希望mindspore说明问昂也把公式写清楚点。其实mindspore的Dense和pytorch的Linear的公式实现是一样的。
    附上pytorch的实现:
#@title Default title text 
import torch
import torch_npu
import torch.nn as nn
class Self_Attention(torch.nn.Module):def __init__(self, dim):super(Self_Attention, self).__init__() #  其中qkv代表构建好训练好的wq,wk,wv的权重参数;self.scale = 2 ** -0.5self.q = torch.nn.Linear(dim, 2, bias=False) q_list = [[1., 0.],[1., 1.]]self.q.weight.data = torch.Tensor(q_list).Tprint('q value:', self.q.weight.data)self.k = nn.Linear(dim, 2, bias=False)k_list = [[0., 1.], [1., 1.]]self.k.weight.data = torch.Tensor(k_list).Tprint('k value:', self.k.weight.data)self.v = nn.Linear(dim,3,bias=False)v_list = [[0., 1., 1.],[1., 0., 0.]]# print("origin shape:", self.v.weight.data.shape)self.v.weight.data = torch.Tensor(v_list).Tprint('init shape:',self.v.weight.data)def forward(self, x):q = self.q(x)  # 通过训练好的参数生成q参数print("q:", q)k = self.k(x)print("k:", k)v = self.v(x)print("v shape:", v.shape)# Att公式attn = (q.matmul(k.T)) / torch.sqrt(torch.tensor(2.0))print("attn1:", attn)# attn = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(2.0))# print("attn11:", attn)# attn = (q @ k.transpose(-2, -1)) * self.scale# print("attn2:", attn)attn = attn.softmax(dim=-1)print("softmax attn:", attn)# print(attn.shape) # shape[4,4]x = attn @ vprint(x.shape)  #shape[4,3]return x 
x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = torch.Tensor(x)
att = Self_Attention(2)  
att(x)

这篇关于SelfAttention|自注意力机制ms简单实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/712440

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu2289(简单二分)

虽说是简单二分,但是我还是wa死了  题意:已知圆台的体积,求高度 首先要知道圆台体积怎么求:设上下底的半径分别为r1,r2,高为h,V = PI*(r1*r1+r1*r2+r2*r2)*h/3 然后以h进行二分 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#includ

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

usaco 1.3 Prime Cryptarithm(简单哈希表暴搜剪枝)

思路: 1. 用一个 hash[ ] 数组存放输入的数字,令 hash[ tmp ]=1 。 2. 一个自定义函数 check( ) ,检查各位是否为输入的数字。 3. 暴搜。第一行数从 100到999,第二行数从 10到99。 4. 剪枝。 代码: /*ID: who jayLANG: C++TASK: crypt1*/#include<stdio.h>bool h

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount