本文主要是介绍Keras实现Senet block模块,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一、keras实现的Senet block模块代码
import keras
class SeBlock(keras.layers.Layer): def __init__(self, reduction=4,**kwargs):super(SeBlock,self).__init__(**kwargs)self.reduction = reductiondef build(self,input_shape):#构建layer时需要实现#input_shape passdef call(self, inputs):x = keras.layers.GlobalAveragePooling2D()(inputs)x = keras.layers.Dense(int(x.shape[-1]) // self.reduction, use_bias=False,activation=keras.activations.relu)(x)x = keras.layers.Dense(int(inputs.shape[-1]), use_bias=False,activation=keras.activations.hard_sigmoid)(x)return keras.layers.Multiply()([inputs,x]) #给通道加权重#return inputs*x
二、Senet block模块调用
outputs=SeBlock()(inputs) #创建一个SeBlock匿名对象,使用对象()调用call方法
这篇关于Keras实现Senet block模块的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!