TensorFlow MatMul操作rank错误问题记录

2024-08-21 20:08

本文主要是介绍TensorFlow MatMul操作rank错误问题记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这个问题应该算是很简单的,只不过我是新手,需要多记录下。在看Stanford的TensorFlow教程(地址为:https://www.youtube.com/watch?v=g-EvyKpZjmQ&list=PLQ0sVbIj3URf94DQtGPJV629ctn2c1zN-)Lecture 1的一段代码的时候,发现并不能运行:

import tensorflow as tfwith tf.device('/gpu:1'):a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='a')b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='b')c = tf.matmul(a, b)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))

报错为:ValueError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [6], [6].

TensorFlow才接触不久,基本都是运行下别人的代码,看看效果,所以对其中的方法也都是混个脸熟,并不十分清楚。这里的tf.matmul()方法和另一个tf.mul()要区分下,tf.mul实际上在新版的TensorFlow中已经修改为tf.multiply()了,我是参考https://blog.csdn.net/liuyuemaicha/article/details/70305678这篇博文学习的,测试下multiply:

import tensorflow as tfa = tf.get_variable('a', [2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
b = tf.get_variable('b', [2, 3], initializer=tf.constant_initializer(2))
c = tf.get_variable('c', [3, 2], initializer=tf.ones_initializer())init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)print('a:\n', sess.run(a))print('b:\n', sess.run(b))print('c:\n', sess.run(c))print('multiply a, b')print(sess.run(tf.multiply(a, b)))print('matmul a, c')print(sess.run(tf.matmul(a, c)))

tf.get_variable()方法的使用第一个参数是name,第二个是shape,第三个是initializer。tf.random_normal_initializer()方法就是返回一个具有正态分布的张量初始化器,均值(期望值)mean默认为0,标准差默认为1,也就是默认为标准正态分布。得到的结果为:

a:
 [[-1.2580129   0.42341614  0.2203044 ]
 [-1.1805797  -1.8744725  -0.1812443 ]]
b:
 [[2. 2. 2.]
 [2. 2. 2.]]
c:
 [[1. 1.]
 [1. 1.]
 [1. 1.]]
multiply a, b
[[-2.5160258  0.8468323  0.4406088]
 [-2.3611593 -3.748945  -0.3624886]]
matmul a, c
[[-0.6142924 -0.6142924]
 [-3.2362967 -3.2362967]]

可以看到tf.multiply()方法是对应位置元素直接相乘的,因此要求二者的shape相等,该操作也成为哈达马积(Hadamard)。a和c两个变量一个是2行3列,一个3行2列,可以用tf.matmul()方法求矩阵乘积,得到了2行2列的一个矩阵。

回到刚刚的问题,比如参考https://blog.csdn.net/blythe0107/article/details/74171870,可以采用reshape的方式,使前者的列等于后者的行也就行了,如下:

import tensorflow as tf
import numpy as npwith tf.device('/gpu:0'):a = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='a')b = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='b')c = tf.matmul(a, b)with tf.device('/gpu:1'):d = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(2, 3), name='d')e = tf.constant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2), name='e')f = tf.matmul(d, e)sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))print(sess.run(c))
print(sess.run(f))

这样得到的输出如下:

2018-08-02 15:52:42.801535: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 1:   Y N
2018-08-02 15:52:42.801871: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10388 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1)
2018-08-02 15:52:42.905229: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1084] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10407 MB memory) -> physical GPU (device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1)
Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1
2018-08-02 15:52:43.010702: I tensorflow/core/common_runtime/direct_session.cc:288] Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:21:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:2d:00.0, compute capability: 6.1

MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011677: I tensorflow/core/common_runtime/placer.cc:886] MatMul: (MatMul)/job:localhost/replica:0/task:0/device:GPU:0
MatMul_1: (MatMul): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011720: I tensorflow/core/common_runtime/placer.cc:886] MatMul_1: (MatMul)/job:localhost/replica:0/task:0/device:GPU:1
a: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011741: I tensorflow/core/common_runtime/placer.cc:886] a: (Const)/job:localhost/replica:0/task:0/device:GPU:0
b: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2018-08-02 15:52:43.011760: I tensorflow/core/common_runtime/placer.cc:886] b: (Const)/job:localhost/replica:0/task:0/device:GPU:0
d: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011778: I tensorflow/core/common_runtime/placer.cc:886] d: (Const)/job:localhost/replica:0/task:0/device:GPU:1
e: (Const): /job:localhost/replica:0/task:0/device:GPU:1
2018-08-02 15:52:43.011795: I tensorflow/core/common_runtime/placer.cc:886] e: (Const)/job:localhost/replica:0/task:0/device:GPU:1
[[22. 28.]
 [49. 64.]]
[[22. 28.]
 [49. 64.]]

可以看到,变量和op可以指定GPU,本例中a和b用了GPU0,另外也处理了matmul()的操作。而d和e即计算f的任务则放在了GPU1上,这个可能算是最简单了单主机多GPU使用了。

关于前面的变量使用,记录如下。

TensorFlow有两个关于variable的op,即tf.Variable()和tf.get_variable(),这里参考

https://blog.csdn.net/u012436149/article/details/53696970/学习下。比如下面的代码:

import tensorflow as tfw_1 = tf.Variable(3, name='w_1')
w_2 = tf.Variable(1, name='w_1')print(w_1.name)
print(w_2.name)init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)sess.run(tf.Print(w_1, [w_1, w_1.name, str(w_1.value)]))sess.run(tf.Print(w_2, [w_2, w_2.name, str(w_2.value)]))

这里使用了tf.Print()方法来输出一些调试信息,其value部分用str()方法处理下不然报错。输出结果:

w_1:0
w_1_1:0

[3][w_1:0][<bound method Variable.value of <tf.Variable \'w_1:0\' shape=() dtype=int32_ref>>]
[1][w_1_1:0][<bound method Variable.value of <tf.Variable \'w_1_1:0\' shape=() dtype=int32_ref>>]

使用tf.Variable()系统会自动处理命名冲突,这里如果用tf.get_variable()则会报错w_1变量已存在。所以当我们需要共享变量的时候,用tf.get_variable()。关于其实质区别,看下这段代码:

import tensorflow as tfwith tf.variable_scope('scope1'):w1 = tf.get_variable('w1', shape=[])w2 = tf.Variable(0.0, name='w_1')with tf.variable_scope('scope1', reuse=True):w1_p = tf.get_variable('w1', shape=[])w2_p = tf.Variable(1.0, name='w2')print(w1 is w1_p, w2 is w2_p)

输出为True False。由于tf.Variable()每次都在创建新对象,所有reuse=True 和它并没有什么关系。对于get_variable(),如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。

 

这篇关于TensorFlow MatMul操作rank错误问题记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1094153

相关文章

线上Java OOM问题定位与解决方案超详细解析

《线上JavaOOM问题定位与解决方案超详细解析》OOM是JVM抛出的错误,表示内存分配失败,:本文主要介绍线上JavaOOM问题定位与解决方案的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一、OOM问题核心认知1.1 OOM定义与技术定位1.2 OOM常见类型及技术特征二、OOM问题定位工具

Python正则表达式匹配和替换的操作指南

《Python正则表达式匹配和替换的操作指南》正则表达式是处理文本的强大工具,Python通过re模块提供了完整的正则表达式功能,本文将通过代码示例详细介绍Python中的正则匹配和替换操作,需要的朋... 目录基础语法导入re模块基本元字符常用匹配方法1. re.match() - 从字符串开头匹配2.

Vue3绑定props默认值问题

《Vue3绑定props默认值问题》使用Vue3的defineProps配合TypeScript的interface定义props类型,并通过withDefaults设置默认值,使组件能安全访问传入的... 目录前言步骤步骤1:使用 defineProps 定义 Props步骤2:设置默认值总结前言使用T

Java实现在Word文档中添加文本水印和图片水印的操作指南

《Java实现在Word文档中添加文本水印和图片水印的操作指南》在当今数字时代,文档的自动化处理与安全防护变得尤为重要,无论是为了保护版权、推广品牌,还是为了在文档中加入特定的标识,为Word文档添加... 目录引言Spire.Doc for Java:高效Word文档处理的利器代码实战:使用Java为Wo

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Debian 13升级后网络转发等功能异常怎么办? 并非错误而是管理机制变更

《Debian13升级后网络转发等功能异常怎么办?并非错误而是管理机制变更》很多朋友反馈,更新到Debian13后网络转发等功能异常,这并非BUG而是Debian13Trixie调整... 日前 Debian 13 Trixie 发布后已经有众多网友升级到新版本,只不过升级后发现某些功能存在异常,例如网络转

sysmain服务可以禁用吗? 电脑sysmain服务关闭后的影响与操作指南

《sysmain服务可以禁用吗?电脑sysmain服务关闭后的影响与操作指南》在Windows系统中,SysMain服务(原名Superfetch)作为一个旨在提升系统性能的关键组件,一直备受用户关... 在使用 Windows 系统时,有时候真有点像在「开盲盒」。全新安装系统后的「默认设置」,往往并不尽编

Web服务器-Nginx-高并发问题

《Web服务器-Nginx-高并发问题》Nginx通过事件驱动、I/O多路复用和异步非阻塞技术高效处理高并发,结合动静分离和限流策略,提升性能与稳定性... 目录前言一、架构1. 原生多进程架构2. 事件驱动模型3. IO多路复用4. 异步非阻塞 I/O5. Nginx高并发配置实战二、动静分离1. 职责2

解决升级JDK报错:module java.base does not“opens java.lang.reflect“to unnamed module问题

《解决升级JDK报错:modulejava.basedoesnot“opensjava.lang.reflect“tounnamedmodule问题》SpringBoot启动错误源于Jav... 目录问题描述原因分析解决方案总结问题描述启动sprintboot时报以下错误原因分析编程异js常是由Ja

Python自动化处理PDF文档的操作完整指南

《Python自动化处理PDF文档的操作完整指南》在办公自动化中,PDF文档处理是一项常见需求,本文将介绍如何使用Python实现PDF文档的自动化处理,感兴趣的小伙伴可以跟随小编一起学习一下... 目录使用pymupdf读写PDF文件基本概念安装pymupdf提取文本内容提取图像添加水印使用pdfplum