TensorFlow MatMul操作rank错误问题记录

2024-08-21 20:08

本文主要是介绍TensorFlow MatMul操作rank错误问题记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这个问题应该算是很简单的,只不过我是新手,需要多记录下。在看Stanford的TensorFlow教程(地址为:https://www.youtube.com/watch?v=g-EvyKpZjmQ&list=PLQ0sVbIj3URf94DQtGPJV629ctn2c1zN-)Lecture 1的一段代码的时候,发现并不能运行:

import tensorflow as tfwith tf.device('/gpu:1'):a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='a')b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='b')c = tf.matmul(a, b)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))

报错为:ValueError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [6], [6].

TensorFlow才接触不久,基本都是运行下别人的代码,看看效果,所以对其中的方法也都是混个脸熟,并不十分清楚。这里的tf.matmul()方法和另一个tf.mul()要区分下,tf.mul实际上在新版的TensorFlow中已经修改为tf.multiply()了,我是参考https://blog.csdn.net/liuyuemaicha/article/details/70305678这篇博文学习的,测试下multiply:

import tensorflow as tfa = tf.get_variable('a', [2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
b = tf.get_variable('b', [2, 3], initializer=tf.constant_initializer(2))
c = tf.get_variable('c', [3, 2], initializer=tf.ones_initializer())init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)print('a:\n', sess.run(a))print('b:\n', sess.run(b))print('c:\n', sess.run(c))print('multiply a, b')print(sess.run(tf.multiply(a, b)))print('matmul a, c')print(sess.run(tf.matmul(a, c)))

tf.get_variable()方法的使用第一个参数是name,第二个是shape,第三个是initializer。tf.random_normal_initializer()方法就是返回一个具有正态分布的张量初始化器,均值(期望值)mean默认为0,标准差默认为1,也就是默认为标准正态分布。得到的结果为:

a:
 [[-1.2580129   0.42341614  0.2203044 ]
 [-1.1805797  -1.8744725  -0.1812443 ]]
b:
 [[2. 2. 2.]
 [2. 2. 2.]]
c:
 [[1. 1.]
 [1. 1.]
 [1. 1.]]
multiply a, b
[[-2.5160258  0.8468323  0.4406088]
 [-2.3611593 -3.748945  -0.3624886]]
matmul a, c
[[-0.6142924 -0.6142924]
 [-3.2362967 -3.2362967]]

可以看到tf.multiply()方法是对应位置元素直接相乘的,因此要求二者的shape相等,该操作也成为哈达马积(Hadamard)。a和c两个变量一个是2行3列,一个3行2列,可以用tf.matmul()方法求矩阵乘积,得到了2行2列的一个矩阵。

回到刚刚的问题,比如参考https://blog.csdn.net/blythe0107/article/details/74171870,可以采用reshape的方式,使前者的列等于后者的行也就行了,如下:

import tensorflow as tf
import numpy as npwith tf.device('/gpu:0'):a = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='a')b = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='b')c = tf.matmul(a, b)with tf.device('/gpu:1'):d = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='d')e = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='e')f = tf.matmul(d, e)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))
print(sess.run(f))

这样得到的输出如下:

2018-08-02 15:52:42.801535: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 1:   Y N
2018-08-02 15:52:42.801871: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10388 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1)
2018-08-02 15:52:42.905229: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10407 MB memory) -> physical GPU (device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1)
Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1
2018-08-02 15:52:43.010702: I tensorflow/core/common_runtime/direct_session.cc:288] Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1

MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011677: I tensorflow/core/common_runtime/placer.cc:886] MatMul: (MatMul)/job:localhost/replica:0/task:0/device:GPU:0
MatMul_1: (MatMul): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011720: I tensorflow/core/common_runtime/placer.cc:886] MatMul_1: (MatMul)/job:localhost/replica:0/task:0/device:GPU:1
a: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011741: I tensorflow/core/common_runtime/placer.cc:886] a: (Const)/job:localhost/replica:0/task:0/device:GPU:0
b: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011760: I tensorflow/core/common_runtime/placer.cc:886] b: (Const)/job:localhost/replica:0/task:0/device:GPU:0
d: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011778: I tensorflow/core/common_runtime/placer.cc:886] d: (Const)/job:localhost/replica:0/task:0/device:GPU:1
e: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011795: I tensorflow/core/common_runtime/placer.cc:886] e: (Const)/job:localhost/replica:0/task:0/device:GPU:1
[[22. 28.]
 [49. 64.]]
[[22. 28.]
 [49. 64.]]

可以看到,变量和op可以指定GPU,本例中a和b用了GPU0,另外也处理了matmul()的操作。而d和e即计算f的任务则放在了GPU1上,这个可能算是最简单了单主机多GPU使用了。

关于前面的变量使用,记录如下。

TensorFlow有两个关于variable的op,即tf.Variable()和tf.get_variable(),这里参考

https://blog.csdn.net/u012436149/article/details/53696970/学习下。比如下面的代码:

import tensorflow as tfw_1 = tf.Variable(3, name='w_1')
w_2 = tf.Variable(1, name='w_1')print(w_1.name)
print(w_2.name)init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)sess.run(tf.Print(w_1, [w_1, w_1.name, str(w_1.value)]))sess.run(tf.Print(w_2, [w_2, w_2.name, str(w_2.value)]))

这里使用了tf.Print()方法来输出一些调试信息,其value部分用str()方法处理下不然报错。输出结果:

w_1:0
w_1_1:0

[3][w_1:0][<bound method Variable.value of <tf.Variable \'w_1:0\' shape=() dtype=int32_ref>>]
[1][w_1_1:0][<bound method Variable.value of <tf.Variable \'w_1_1:0\' shape=() dtype=int32_ref>>]

使用tf.Variable()系统会自动处理命名冲突,这里如果用tf.get_variable()则会报错w_1变量已存在。所以当我们需要共享变量的时候,用tf.get_variable()。关于其实质区别,看下这段代码:

import tensorflow as tfwith tf.variable_scope('scope1'):w1 = tf.get_variable('w1', shape=[])w2 = tf.Variable(0.0, name='w_1')with tf.variable_scope('scope1', reuse=True):w1_p = tf.get_variable('w1', shape=[])w2_p = tf.Variable(1.0, name='w2')print(w1 is w1_p, w2 is w2_p)

输出为True False。由于tf.Variable()每次都在创建新对象,所有reuse=True 和它并没有什么关系。对于get_variable(),如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。

 

这篇关于TensorFlow MatMul操作rank错误问题记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1094153

相关文章

Python中pywin32 常用窗口操作的实现

《Python中pywin32常用窗口操作的实现》本文主要介绍了Python中pywin32常用窗口操作的实现,pywin32主要的作用是供Python开发者快速调用WindowsAPI的一个... 目录获取窗口句柄获取最前端窗口句柄获取指定坐标处的窗口根据窗口的完整标题匹配获取句柄根据窗口的类别匹配获取句

MyBatis模糊查询报错:ParserException: not supported.pos 问题解决

《MyBatis模糊查询报错:ParserException:notsupported.pos问题解决》本文主要介绍了MyBatis模糊查询报错:ParserException:notsuppo... 目录问题描述问题根源错误SQL解析逻辑深层原因分析三种解决方案方案一:使用CONCAT函数(推荐)方案二:

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

Redis 热 key 和大 key 问题小结

《Redis热key和大key问题小结》:本文主要介绍Redis热key和大key问题小结,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、什么是 Redis 热 key?热 key(Hot Key)定义: 热 key 常见表现:热 key 的风险:二、

IntelliJ IDEA 中配置 Spring MVC 环境的详细步骤及问题解决

《IntelliJIDEA中配置SpringMVC环境的详细步骤及问题解决》:本文主要介绍IntelliJIDEA中配置SpringMVC环境的详细步骤及问题解决,本文分步骤结合实例给大... 目录步骤 1:创建 Maven Web 项目步骤 2:添加 Spring MVC 依赖1、保存后执行2、将新的依赖

Spring 中的循环引用问题解决方法

《Spring中的循环引用问题解决方法》:本文主要介绍Spring中的循环引用问题解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录什么是循环引用?循环依赖三级缓存解决循环依赖二级缓存三级缓存本章来聊聊Spring 中的循环引用问题该如何解决。这里聊

Spring Boot中JSON数值溢出问题从报错到优雅解决办法

《SpringBoot中JSON数值溢出问题从报错到优雅解决办法》:本文主要介绍SpringBoot中JSON数值溢出问题从报错到优雅的解决办法,通过修改字段类型为Long、添加全局异常处理和... 目录一、问题背景:为什么我的接口突然报错了?二、为什么会发生这个错误?1. Java 数据类型的“容量”限制

关于MongoDB图片URL存储异常问题以及解决

《关于MongoDB图片URL存储异常问题以及解决》:本文主要介绍关于MongoDB图片URL存储异常问题以及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录MongoDB图片URL存储异常问题项目场景问题描述原因分析解决方案预防措施js总结MongoDB图

Python ZIP文件操作技巧详解

《PythonZIP文件操作技巧详解》在数据处理和系统开发中,ZIP文件操作是开发者必须掌握的核心技能,Python标准库提供的zipfile模块以简洁的API和跨平台特性,成为处理ZIP文件的首选... 目录一、ZIP文件操作基础三板斧1.1 创建压缩包1.2 解压操作1.3 文件遍历与信息获取二、进阶技

SpringBoot项目中报错The field screenShot exceeds its maximum permitted size of 1048576 bytes.的问题及解决

《SpringBoot项目中报错ThefieldscreenShotexceedsitsmaximumpermittedsizeof1048576bytes.的问题及解决》这篇文章... 目录项目场景问题描述原因分析解决方案总结项目场景javascript提示:项目相关背景:项目场景:基于Spring