本文主要是介绍CNN记录】pytorch中flatten函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
pytorch原型
torch.flatten(input, start_dim=0, end_dim=- 1)
作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,
参数:
start_dim:开始的维度
end_dim:终止的维度,-1为最后一个轴
默认值时展平为1维
例子
1、默认参数
input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])
2、设置参数
input = torch.randn(2, 3, 4, 5)output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])
这篇关于CNN记录】pytorch中flatten函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!