本文主要是介绍tf.squeeze/tf.expand_dims,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在git上的FM开源代码中看到了这样子的用法
https://github.com/Aifcce/FM-FFM/blob/master/model/FM.py
batch_weights = tf.squeeze(batch_weights, axis=2)
df_v = tf.expand_dims(df_v, axis=2)
tf.squeeze是降维,把维度为1的去掉,我的理解是,这个代码把dense feature和sparse feature在数据预处理时放到了一起,在进行embedding look up时,dense feature的维度为1(index id为同一个值),因此要进行过滤。
而每个sparse feature的维度不一样,用 tf.expand_dims把维度拉齐。
这篇关于tf.squeeze/tf.expand_dims的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!