SelfAttention|自注意力机制ms简单实现

2024-02-15 20:20

本文主要是介绍SelfAttention|自注意力机制ms简单实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

自注意力机制学习有感

  • 观看b站博主的讲解视频以及跟着他的pytorch代码实现mindspore的自注意力机制:
  • up主讲的很好,推荐入门自注意力机制。
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
context.set_context(device_target='Ascend',max_device_memory='1GB') class SelfAttention(nn.Cell):def __init__(self, dim):super(SelfAttention, self).__init__()wq_data = [[1.0, 0], [1., 1.]] # wq权重初始化 超参数wk_data = [[0., 1.], [1., 1.]] # wk权重初始化 超参数wv_data = [[0., 1., 1.], [1., 0., 0.]] # wv权重初始化 超参数self.q = nn.Dense(in_channels=dim, out_channels=2, has_bias=False)self.q.weight.set_data(ms.Tensor(wq_data).T)print("wq value:", self.q.weight.value())self.k = nn.Dense(in_channels = dim, out_channels=2, has_bias=False)self.k.weight.set_data(ms.Tensor(wk_data).T)print('wk value:', self.k.weight.value())self.v = nn.Dense(in_channels=dim, out_channels=3, has_bias=False)# print(self.v.weight.shape)self.v.weight.set_data(ms.Tensor(wv_data).T)print('wv value:',self.v.weight.value())print("*********************" * 2)def construct(self, x):q = self.q(x)print('q value:', q)k = self.k(x)print('k value:', k)v = self.v(x)# xx = x.matmul(ms.Tensor([[0., 1., 1.], [1., 0., 0.]]))print('v value:', v, '\n')print('#################################')x = (q @ k.T)/ms.ops.sqrt(ms.tensor(2.))x = ms.ops.softmax(x) @ vprint("result:", x)x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = ms.Tensor(x)
attn = SelfAttention(2)
attn(x)

结果如下:

wq value: [[1. 1.][0. 1.]]
wk value: [[0. 1.][1. 1.]]
wv value: [[0. 1.][1. 0.][1. 0.]]
******************************************
q value: [[2. 1.][1. 0.][3. 1.][2. 2.]]
k value: [[1. 2.][0. 1.][1. 3.][2. 2.]]
v value: [[1. 1. 1.][0. 1. 1.][1. 2. 2.][2. 0. 0.]] #################################
result: [[1.5499581  0.71284014 0.71284014][1.3395231  0.7726004  0.7726004 ][1.7247156  0.4475609  0.4475609 ][1.4366053  1.         1.        ]]

** 吐槽mindspore说明文档,对ms.nn.Dense的说明太过简单了,有对新手真不友好(对我) **

  • pytorch的文档:
    在这里插入图片描述
  • mindspore的文档:
    在这里插入图片描述
    pytorch有公式,至少提示A的转置有提示。mindspore没有,导致我这步实现的时候输出的结果不对,还是希望mindspore说明问昂也把公式写清楚点。其实mindspore的Dense和pytorch的Linear的公式实现是一样的。
    附上pytorch的实现:
#@title Default title text 
import torch
import torch_npu
import torch.nn as nn
class Self_Attention(torch.nn.Module):def __init__(self, dim):super(Self_Attention, self).__init__() #  其中qkv代表构建好训练好的wq,wk,wv的权重参数;self.scale = 2 ** -0.5self.q = torch.nn.Linear(dim, 2, bias=False) q_list = [[1., 0.],[1., 1.]]self.q.weight.data = torch.Tensor(q_list).Tprint('q value:', self.q.weight.data)self.k = nn.Linear(dim, 2, bias=False)k_list = [[0., 1.], [1., 1.]]self.k.weight.data = torch.Tensor(k_list).Tprint('k value:', self.k.weight.data)self.v = nn.Linear(dim,3,bias=False)v_list = [[0., 1., 1.],[1., 0., 0.]]# print("origin shape:", self.v.weight.data.shape)self.v.weight.data = torch.Tensor(v_list).Tprint('init shape:',self.v.weight.data)def forward(self, x):q = self.q(x)  # 通过训练好的参数生成q参数print("q:", q)k = self.k(x)print("k:", k)v = self.v(x)print("v shape:", v.shape)# Att公式attn = (q.matmul(k.T)) / torch.sqrt(torch.tensor(2.0))print("attn1:", attn)# attn = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(2.0))# print("attn11:", attn)# attn = (q @ k.transpose(-2, -1)) * self.scale# print("attn2:", attn)attn = attn.softmax(dim=-1)print("softmax attn:", attn)# print(attn.shape) # shape[4,4]x = attn @ vprint(x.shape)  #shape[4,3]return x 
x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = torch.Tensor(x)
att = Self_Attention(2)  
att(x)

这篇关于SelfAttention|自注意力机制ms简单实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/712440

相关文章

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

Java枚举类实现Key-Value映射的多种实现方式

《Java枚举类实现Key-Value映射的多种实现方式》在Java开发中,枚举(Enum)是一种特殊的类,本文将详细介绍Java枚举类实现key-value映射的多种方式,有需要的小伙伴可以根据需要... 目录前言一、基础实现方式1.1 为枚举添加属性和构造方法二、http://www.cppcns.co

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java实现文件图片的预览和下载功能

《Java实现文件图片的预览和下载功能》这篇文章主要为大家详细介绍了如何使用Java实现文件图片的预览和下载功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... Java实现文件(图片)的预览和下载 @ApiOperation("访问文件") @GetMapping("

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

Mysql表的简单操作(基本技能)

《Mysql表的简单操作(基本技能)》在数据库中,表的操作主要包括表的创建、查看、修改、删除等,了解如何操作这些表是数据库管理和开发的基本技能,本文给大家介绍Mysql表的简单操作,感兴趣的朋友一起看... 目录3.1 创建表 3.2 查看表结构3.3 修改表3.4 实践案例:修改表在数据库中,表的操作主要

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义