本文主要是介绍BERT or Transformer中,MHSA中为什么要分多个Head?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
之前面试被问过的一道题,这里整理一下~
结论:模型的表达学习能力增强了
输入到MHSA中的数据的shape应该为B × L × Embedding,B是Batch,L是序列长度
而在MHSA中,数据的shape会被拆分为多个Head,所以shape会进一步变为:
B × L × Head × Little_Embedding
以Transformer为例,原始论文中Embedding为512,Head数为8,所以shape在进入MHSA中时,会变为:
B × L × 8 × 64
如果不分头,相当于对512*512的矩阵进行Attention计算;
而如果分头了,相当于8个头中,每个头彼此独立进行Attention计算,不同头学习到的特征也可能是不同的,相当于增强了模型的表达学习能力。
并且,8 次的 64 × 64,和 1 次的 512 × 512,两者的计算复杂度是一致的,并没有造成额外的计算开销。

这篇关于BERT or Transformer中,MHSA中为什么要分多个Head?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!