This article introduces the feed-forward network (FFN) used in the Transformer and gives a minimal PyTorch implementation with a shape test.
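For reference, the original Transformer paper ("Attention Is All You Need") defines the position-wise FFN as two linear transformations with a ReLU in between, which is the structure `FeedForwardNetwork` below follows. The dimension names here are chosen to mirror the code's `input_dim`, `hidden_dim` and `output_dim`. Note that `nn.Linear` stores its weight as `(out_features, in_features)` and computes $xA^\top + b$, so $W_1$ corresponds to `linear1.weight` transposed.

$$
\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2,
\qquad W_1 \in \mathbb{R}^{d_{\text{in}} \times d_{\text{hidden}}},\;
W_2 \in \mathbb{R}^{d_{\text{hidden}} \times d_{\text{out}}}
$$

Counting weights and biases for the test dimensions ($d_{\text{in}} = 4$, $d_{\text{hidden}} = 8$, $d_{\text{out}} = 4$):

$$
d_{\text{in}}\,d_{\text{hidden}} + d_{\text{hidden}} + d_{\text{hidden}}\,d_{\text{out}} + d_{\text{out}}
= 4\cdot 8 + 8 + 8\cdot 4 + 4 = 76
$$

In the paper itself the input and output width is $d_{\text{model}} = 512$ and the inner width is $d_{ff} = 2048$.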
```python
import torch
import torch.nn as nn


# Define the FFN layer: Linear -> ReLU -> Linear
class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FeedForwardNetwork, self).__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)   # expand to hidden_dim
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)  # project to output_dim

    def forward(self, x):
        # x: (..., input_dim); nn.Linear acts on the last dimension only
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


# Test the FFN layer
def test_ffn():
    input_dim = 4
    hidden_dim = 8
    output_dim = 4
    batch_size = 5
    seq_length = 6

    # Create the FFN layer
    ffn = FeedForwardNetwork(input_dim, hidden_dim, output_dim)

    # Create random input data of shape (batch_size, seq_length, input_dim)
    input_data = torch.randn(batch_size, seq_length, input_dim)
    print(input_data)

    # Forward pass
    output_data = ffn(input_data)
    print("Input shape:", input_data.shape)    # torch.Size([5, 6, 4])
    print("Output shape:", output_data.shape)  # torch.Size([5, 6, 4])


if __name__ == "__main__":
    test_ffn()
```
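Because `nn.Linear` only acts on the last dimension, the FFN is applied to every sequence position independently, which is why the Transformer paper calls it "position-wise". The sketch below checks that property. It assumes an `nn.Sequential` stand-in with the same Linear → ReLU → Linear structure and the same 4 → 8 → 4 dimensions as `FeedForwardNetwork`; the stand-in, the seed and the tolerance are illustrative choices, not part of the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the check is reproducible

# Stand-in with the same structure as FeedForwardNetwork(4, 8, 4)
ffn = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))

x = torch.randn(5, 6, 4)  # (batch_size, seq_length, input_dim)

# Run the whole (batch, seq, feature) tensor through the FFN at once
full = ffn(x)

# Run the same FFN on each sequence position separately, then re-stack
per_position = torch.stack([ffn(x[:, t, :]) for t in range(x.shape[1])], dim=1)

# Same result (up to float rounding): positions do not interact in the FFN
print(full.shape)                                      # torch.Size([5, 6, 4])
print(torch.allclose(full, per_position, atol=1e-6))   # True
```

In a Transformer block, the FFN therefore only mixes information across the feature dimension of each token; mixing information across positions is the job of the attention sublayer.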