本文主要是介绍tf实现用二维的索引从二维数组获取对应值 tf.gather_nd,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
inds = tf.constant([[0, 2], [2, 1], [1, 1]])#目的是实现 从[1,2,3]获取index为[0,2]的值也就是[1,3]作为第一行,
从[4,5,6]获取index为[2,1]的值也就是[6,5]作为第二行,
从[7,8,9]获取index[1,1]的值作为第三行,也就是输出是
[[1 3][6 5][8 8]]
这种需求应该很常见,但是想通过look_up_table好像不行,以及想通过tf.gather_fn似乎可以但是也不好写
本文提供一种写法:
import tensorflow as tfdef gather_batch(v, inds):return tf.gather(v, inds)def test2():a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])inds = tf.constant([[0, 2], [2, 1], [1, 1]])vs = tf.map_fn(fn=lambda x: gather_batch(x[:3], x[3:]), elems=tf.concat([a, inds], 1))with tf.Session() as sess:print(sess.run(vs))if __name__ == '__main__':# test1()test2()
但是上面写法还是用了循环 会很慢 所以更好写法
def test3():a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])inds = tf.constant([[0, 2], [2, 1], [1, 1]])batch_size = inds.shape[0]cnt = inds.shape[1]left_inds = tf.tile(tf.expand_dims(tf.range(batch_size), 1),[1, cnt])ind = tf.squeeze(tf.stack([tf.expand_dims(left_inds, 2),tf.expand_dims(inds, 2),],2),-1)vs = tf.gather_nd(a, ind)with tf.Session() as sess:# print(sess.run(ind))print(sess.run(vs))
这篇关于tf实现用二维的索引从二维数组获取对应值 tf.gather_nd的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!