This article walks through a Python (PyTorch) implementation of the self-attention mechanism (SelfAttention). We hope it offers a useful reference for developers working on this problem; follow along to learn with us!
Self-Attention is an attention mechanism, similar to Attention. The difference is that Attention runs from source to target: the input source and the output target are different content. For example, in English-to-Chinese translation, the input is English and the output is Chinese. Self-Attention, by contrast, runs from source to source: it is attention between elements within the source (or within the target), and can also be viewed as the special case of Attention where Target = Source.
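Both mechanisms share the same core computation, scaled dot-product attention: the attention weights are softmax(QKᵀ / scale), and the output is those weights applied to V; they differ only in where Q, K and V come from. As a point of reference, below is a minimal sketch of that core step (the class name ScaledDotProductAttention and its interface are assumptions chosen to match the multi-head module used by the code that follows, not necessarily the original implementation):

import numpy as np
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """ Scaled dot-product attention: softmax(q k^T / scale) v """

    def __init__(self, scale):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        # Similarity scores between every query and every key
        u = torch.bmm(q, k.transpose(1, 2)) / self.scale
        if mask is not None:
            # Positions where mask is True are excluded from attention
            u = u.masked_fill(mask, -np.inf)
        attn = self.softmax(u)       # attention weights over the keys
        output = torch.bmm(attn, v)  # weighted sum of the values
        return attn, output

In self-attention, the q, k and v passed into this step are all projections of the same input sequence, as the implementation below shows.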
import numpy as np
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """ Self-Attention """

    def __init__(self, n_head, d_k, d_v, d_x, d_o):
        super(SelfAttention, self).__init__()
        # Projection matrices that map the input x to queries, keys and values
        self.wq = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wk = nn.Parameter(torch.Tensor(d_x, d_k))
        self.wv = nn.Parameter(torch.Tensor(d_x, d_v))

        # MultiHeadAttention is not defined in this snippet; a sketch is given after the results below
        self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization scaled by the last dimension of each parameter
        for param in self.parameters():
            stdv = 1. / np.power(param.size(-1), 0.5)
            param.data.uniform_(-stdv, stdv)

    def forward(self, x, mask=None):
        # In self-attention, q, k and v are all projected from the same input x
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)

        attn, output = self.mha(q, k, v, mask=mask)

        return attn, output


if __name__ == "__main__":
    n_x = 4
    d_x = 80
    batch = 2

    x = torch.randn(batch, n_x, d_x)
    mask = torch.zeros(batch, n_x, n_x).bool()

    selfattn = SelfAttention(n_head=8, d_k=128, d_v=64, d_x=80, d_o=80)
    attn, output = selfattn(x, mask=mask)

    print(attn.size())
    print(output.size())
Output:
torch.Size([16, 4, 4])
torch.Size([2, 4, 80])
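The SelfAttention class above depends on a MultiHeadAttention module that is not included in this snippet. For completeness, here is a minimal sketch consistent with the constructor call MultiHeadAttention(n_head=..., d_k_=..., d_v_=..., d_k=..., d_v=..., d_o=...) and with the shapes printed above; the internal layer names and the head-folding layout are assumptions, not necessarily the original author's code. It reuses the ScaledDotProductAttention sketch from earlier.

import numpy as np
import torch.nn as nn
# ScaledDotProductAttention is taken from the sketch earlier in this article


class MultiHeadAttention(nn.Module):
    """ Multi-head attention built on top of ScaledDotProductAttention """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Project inputs of size d_k_ / d_v_ to n_head heads of size d_k / d_v
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)

        self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))

        # Concatenate the heads and project back to d_o
        self.fc_o = nn.Linear(n_head * d_v, d_o)

    def forward(self, q, k, v, mask=None):
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch, n_q, _ = q.size()
        batch, n_k, _ = k.size()
        batch, n_v, _ = v.size()

        q = self.fc_q(q)
        k = self.fc_k(k)
        v = self.fc_v(v)

        # Fold the head dimension into the batch dimension:
        # (batch, n, n_head * d) -> (n_head * batch, n, d)
        q = q.view(batch, n_q, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_k)
        k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)
        v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)

        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)

        attn, output = self.attention(q, k, v, mask=mask)

        # Undo the folding and concatenate the heads: (batch, n_q, n_head * d_v)
        output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)
        output = self.fc_o(output)

        return attn, output

Under this sketch, the attention map keeps the folded shape (n_head * batch, n_q, n_k), which is why the first printed size is torch.Size([16, 4, 4]) for n_head=8 and batch=2, while the output is projected back to (batch, n_x, d_o) = (2, 4, 80).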
That concludes this article on the Python implementation of the self-attention mechanism (SelfAttention). We hope it proves helpful to fellow programmers!