本文主要是介绍Pytorch基础:Tensor的flatten方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
相关阅读
Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,flatten是Tensor的一个重要方法,同时它也是一个torch模块中的一个函数,它们的语法如下所示。
Tensor.flatten(start_dim=0, end_dim=-1) → Tensor
torch.flatten(input, start_dim=0, end_dim=-1) → Tensorinput (Tensor) – the input tensor
start_dim (int) – the first dim to flatten
end_dim (int) – the last dim to flatten
flatten函数(或方法)用于将一个张量以特定方法展平, 如果传递了一个参数,则会将从start_dim到end_dim之间的维度展开。默认情况下,flatten将从第0维展平至最后1维。
可以看几个例子以更好的理解:
import torch# 创建一个张量
x = torch.rand(3, 3, 3)# 使用flatten函数,展平x张量
y=x.flatten()
print(x)
tensor([[[0.2581, 0.8408, 0.0216],[0.6353, 0.9141, 0.4098],[0.6391, 0.9829, 0.3967]],[[0.2167, 0.8983, 0.6492],[0.1947, 0.4953, 0.3281],[0.1740, 0.2092, 0.2048]],[[0.3972, 0.6290, 0.3010],[0.6107, 0.5429, 0.7515],[0.7950, 0.0538, 0.8963]]])print(y)
tensor([0.2581, 0.8408, 0.0216, 0.6353, 0.9141, 0.4098, 0.6391, 0.9829, 0.3967,0.2167, 0.8983, 0.6492, 0.1947, 0.4953, 0.3281, 0.1740, 0.2092, 0.2048,0.3972, 0.6290, 0.3010, 0.6107, 0.5429, 0.7515, 0.7950, 0.0538, 0.8963])print(id(x),id(y))
1185516393792 1185516395312 # 说明两个张量对象不同print(x.storage().data_ptr(), y.storage().data_ptr())
1185641974912 1185641974912 # 说明两个张量对象里面保存的数据存储是共享的print(id(x[0,0,0]), id(y[0]))
1186163118464 1186163118464 # 进一步说明两个张量对象里面保存的数据存储是共享的
这篇关于Pytorch基础:Tensor的flatten方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!