This article covers how to set up Batch Normalization when fine-tuning with TensorFlow's slim module; hopefully it serves as a useful reference for anyone hitting this problem.
TensorFlow's BatchNorm is arguably one of the biggest pitfalls in the framework. The problem people run into most often: when fine-tuning, you load a pretrained model, training looks great, and then the model falls flat at test time.
This happens because batch normalization needs an estimate of the mean and variance over the whole dataset for inference. In the implementation, BN obtains this estimate with a moving average: every batch nudges moving_mean and moving_variance toward the current batch statistics. So when using BN, the ops that perform these moving-average updates are registered in the tf.GraphKeys.UPDATE_OPS collection, and you have to make sure they actually run during training.
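For intuition, the update registered at each training step amounts to the following (a plain-Python sketch, not the library code; the decay value matches the batch_norm_decay in the arg_scope below):

def moving_average_update(moving_stat, batch_stat, decay=0.9997):
    # Exponential moving average: with decay close to 1, the moving
    # statistic changes slowly and settles near the population value.
    return decay * moving_stat + (1.0 - decay) * batch_stat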
Take the arg_scope of Inception v3 as an example:
# Imports as in tensorflow/contrib/slim's inception_v3.py:
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import arg_scope
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib.layers.python.layers import regularizers
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn_ops


def inception_v3_arg_scope(weight_decay=0.00004,
                           batch_norm_var_collection='moving_vars',
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=0.001,
                           updates_collections=ops.GraphKeys.UPDATE_OPS,
                           use_fused_batchnorm=True):
  """Defines the default InceptionV3 arg scope.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.
    batch_norm_decay: Decay for batch norm moving average
    batch_norm_epsilon: Small float added to variance to avoid division by zero
    updates_collections: Collections for the update ops of the layer
    use_fused_batchnorm: Enable fused batchnorm.

  Returns:
    An `arg_scope` to use for the inception v3 model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      # collection containing update_ops.
      'updates_collections': updates_collections,
      # Use fused batch norm if possible.
      'fused': use_fused_batchnorm,
      # collection containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': [batch_norm_var_collection],
          'moving_variance': [batch_norm_var_collection],
      }
  }

  # Set weight_decay for weights in Conv and FC layers.
  with arg_scope(
      [layers.conv2d, layers_lib.fully_connected],
      weights_regularizer=regularizers.l2_regularizer(weight_decay)):
    with arg_scope(
        [layers.conv2d],
        weights_initializer=initializers.variance_scaling_initializer(),
        activation_fn=nn_ops.relu,
        normalizer_fn=layers_lib.batch_norm,
        normalizer_params=batch_norm_params) as sc:
      return sc
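For reference, this is how such an arg_scope is typically consumed when building the network for fine-tuning (a minimal sketch; it assumes the slim nets package from the tensorflow/models repository is importable as nets.inception_v3):

import tensorflow as tf
from nets import inception_v3  # assumption: tensorflow/models research/slim on PYTHONPATH

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    # is_training=True makes batch_norm use batch statistics and register
    # the moving-average update ops in tf.GraphKeys.UPDATE_OPS.
    logits, end_points = inception_v3.inception_v3(
        images, num_classes=1001, is_training=True)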
As batch_norm_params shows, the update ops for moving_mean and moving_variance are registered in ops.GraphKeys.UPDATE_OPS, so the training loop must run this collection explicitly.
Example:
opt = tf.train.AdamOptimizer(learning_rate=lr_v)
# Fetch the moving-average update ops that batch_norm registered.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Make the train step depend on them.
with tf.control_dependencies([tf.group(*update_ops)]):
    optimizer = opt.minimize(loss)
This snippet ensures that every step minimizing the loss also runs the BN moving-average updates.
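Alternatively, if you build the train op through slim itself, slim.learning.create_train_op includes tf.GraphKeys.UPDATE_OPS as a dependency by default, so the explicit control_dependencies block is not needed (a sketch, reusing loss and opt from above):

slim = tf.contrib.slim

# create_train_op wires tf.GraphKeys.UPDATE_OPS into the train op by default.
train_op = slim.learning.create_train_op(loss, opt)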
With that, the problem is solved.
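As a quick sanity check that the wiring works, you can watch moving_mean drift after a training step. Below is a minimal, self-contained toy sketch (a single slim.batch_norm layer on random data, not the article's actual model):

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

# Toy graph: one batch-norm layer over random 2-D data.
x = tf.placeholder(tf.float32, [None, 4])
y = slim.batch_norm(x, decay=0.9, is_training=True)
loss = tf.reduce_mean(tf.square(y))

opt = tf.train.AdamOptimizer(1e-3)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([tf.group(*update_ops)]):
    train_op = opt.minimize(loss)

moving_mean = [v for v in tf.global_variables() if 'moving_mean' in v.name][0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(moving_mean)
    # Feed data with mean ~5 so the moving mean visibly drifts away from 0.
    sess.run(train_op, feed_dict={x: np.random.randn(32, 4).astype(np.float32) + 5.0})
    after = sess.run(moving_mean)
    print('moving_mean before:', before)  # starts at zeros
    print('moving_mean after :', after)   # drifts toward the batch mean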
Reference: https://blog.csdn.net/qq_25737169/article/details/79616671