本文主要是介绍STN_空间变换网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
STN_空间变换网络
深度卷积网络虽然已经在很多领域取得了较好的效果,但这些模型依旧十分脆弱。例如,对一幅图像进行平移、旋转和缩放等操作后,会使原有的模型识别准确度下降,这种现象可以理解为深度卷积网络的一个通病,一般可以从两方面入手:
一是样本多样性,数据增强,进行更多的变换,令模型见多识广,可以处理各种角度的图片。
二是样本预处理,一般会采用仿射变换对现有的图片进行修正,令后面的卷积网络专门负责处理调整后的标准图片,使模型训练起来更容易。
空间变换网络(Saptial Transformer Network,STN)模型,是仿射变换领域最基础的文字识别模型之一。该模型的功能是,在训练过程中自动学习对原始图片进行平移、缩放、旋转等扭曲变换的参数,将输入图片的内容调整变成统一的模式,以便被更好地识别。
文章目录
- STN_空间变换网络
- 一、随机生成模拟数据集
- 二、STN的组成结构
- 三、整体模型结构
- 四、训练过程
- 五、深入思考
- 六、源码
- 七、相关链接
一、随机生成模拟数据集
使用captcha库生成字符图片共分为两步:
(1)实例化captcha模块的ImageCaptcha类,并指定图片的尺寸和字体。
(2)调用ImageCaptcha类对象的generate_image方法,传入字符即可生成字符图片。
随机模拟生成10000张字符图片,每张图片(60, 60)大小,总类别数26,包含字母A—Z。这些字符图片不同于标准的手写数字体数据集:每个字符在整张图片的大小尺寸占比并不固定,并且字符会有较大幅度的倾斜。如下所示:
最后输入到网络中的训练数据的结构:
train_x:(batch, 60, 60, 3)。
train_y:(batch, 26)。
二、STN的组成结构
STN模型由3部分组成,具体如下:
1、仿射参数:一般是由一个全连接网络实现的,该网络最终输出6个数值,将其变为2*3矩阵,其中每两个为一组,分别代表仿射变换中的平移、旋转、缩放所对应的参数。
2、坐标映射:创建一个与输入图片大小相同的矩阵,把该矩阵与仿射参数矩阵相乘,把所得结果当作目标图片中每个像素点对应于原图的坐标。
3、采样器:使用坐标映射部分中每个像素点的坐标值,在原始图片中取相应的像素,并通过加权系数将其填充到目标图片中,最终得到整幅目标图片。
本质上,STN 网络无非是自定义神经网络层重新实现了一遍图像的仿射变换算法。但不同的是,它将仿射变换需要的6个参数融合到神经网络思想中。一是可以利用神经网络强大的模型复杂度和自由度,设置许多的全连接权重,加权输出六个仿射变换系数;二是借助有监督的训练方式,可以根据图片提取到的特征,在不同情况下自适应生成仿射变化参数,进行合适的仿射变换。
图像仿射变换算法的原理如下:
第1步:对整张图片进行网格划分,按[-1, 1]生成一系列像素位置的坐标,生成(3, n)坐标矩阵。
第2步:将仿射变换的6个参数,化为(2, 3)的参数矩阵[[a11,a12, a13]; [a21, a22, a23]],其中第1列-第2列记录旋转和缩放信息,第3列记录平移信息。与原始坐标矩阵相乘:
(x_new, y_new).T = [[a11,a12, a13]; [a21, a22, a23]] * (x, y, 1).T
x_new = a11 * x + a12 * y + a13
y_new = a21 * x + a22 * y + a23
由此生成仿射变换后的新坐标。
第3步:对每个新坐标计算上下左右的坐标位置,还原到原图像对应的索引,取出其像素值,按照面积加权将这四个位置的像素值进行融合,作为新坐标位置的像素值。
三、整体模型结构
在实际应用中,通常把STN网络嵌入到其他模型的初始位置,用以提升分类模型的精确度。
整个网络处理流程如下:
(1)原始图片进行卷积池化,提取得到一系列特征。
(2)将特征传入全连接层,输出6个仿射变换参数,利用STN层对原图做仿射变换。
(3)将仿射变换后的图片传入后续模型,搭建分类网络。
整体模型结构如下:
第一部分:模型输入。
image = Input(shape=(60, 60, 3))。
第二部分:卷积池化提取特征,用以拟合输出6个仿射变换参数。
Conv2D(512, 5, strides=3, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
Conv2D(256, 3, strides=2, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
Conv2D(64, 3, strides=2, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
第三部分:基于生成的6个仿射变换参数,利用STN网络进行图像矫正。
Conv2D(20, (3, 3), strides=(1, 1), padding=‘same’, activation=‘relu’)。
GlobalAveragePooling2D()。
Dense(6, kernel_initializer=‘zeros’, bias_initializer=tf.keras.initializers.constant([[1.0, 0, 0], [0, 1.0, 0]]))。
stn_transformer(sampling_size, name=‘stn_transformer’)([image, loc_net])。
第四部分:将矫正后的图片传入后续分类模型。
Conv2D(512, 5, strides=3, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
Conv2D(256, 3, strides=2, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
Conv2D(64, 3, strides=2, padding=‘same’, activation=‘relu’)。
BatchNormalization(renorm=True)。
Dropout(0.2)。
第五部分:模型输出。
Conv2D(num_classes, (3, 3), strides=(1, 1), padding=‘same’, activation=‘relu’)。
GlobalAveragePooling2D()。
Activation(‘softmax’)。
四、训练过程
训练400个epoch,优化器optimizer=‘adam’,最佳模型的分类精度达到99.4%,非常满意。
与不加STN层的普通分类网络进行对比,其特征提取结构完全相同,训练400个epoch,优化器optimizer=‘adam’,最佳分类精度只有98.5%,不如STN_CNN网络。
仔细比较了两种模型的训练结果,STN_CNN由于最开始加入了一个仿射变换预处理层,所以训练起来会更难一些,损失函数下降较为缓慢,但最终预测效果明显好于普通CNN。
将STN层对原始图片的形变效果进行可视化,的确有效,最主要的矫正有两点:一是相应字符区域经过STN层调整之后尺寸变大,二是原本字符所在位置经过调整后渐近居中。
效果如下所示:(左边是原始图像,右边是STN层矫正后的图像,中间用黄色区域分隔开来)
五、深入思考
Ques1:在下采样操作中尽可能使用步长为2的卷积操作代替池化。
相关论文中阐述过这样一个关于卷积网络缺陷的例子:一个训练好的卷积网络只能根据局部特征来处理比较接近训练数据集的图像,但在处理异常图像,比如颠倒、倾斜、其他朝向的相关图像时,卷积神经网络则会表现得很差,相关论文可在arXiv网站中搜索论文编号"1710.09829"。
造成这种现象的原因是,卷积神经网络中的池化操作弄丢了一些隐含信息,这使得它只能发现局部组建的特征,不能发现组建之间的定向关系和相对空间关系。
在卷积神经网络中,池化操作可以让局部特征更明显,但在提升局部特征的同时也弄丢了其内在的其他信息,如位置信息。如果在所处理的任务中包含组建间的位置关系,则在所搭建的卷积神经网络结构中尽量不要使用池化操作,可以在网络中将下采样行为由池化操作转换为卷积运算。常用的下采样操作会将尺寸缩小一半,对于这种情况,可以使用步长为2的卷积代替。
YOLOV3模型中的Darknet-53模型使用的就是这种技术,另外,在EfficientNet模型的输出层中,也将最大池化换成了步长为2的卷积运算。
Ques2:批量归一化与激活函数的位置关系。
批量归一化层(Batch Normalization)和激活函数的先后顺序在不同情况下会不同,其本质上取决于值域间的变换关系,由于不同的激活函数有不同的值域,不能一概而论。
针对sigmoid激活函数,当x值大于7.5或者小于-7.5时,在直角坐标系中,所对应的y值几乎不变,这表明sigmoid激活函数对过大或过小的数无法产生激活作用,这种令sigmoid激活函数失效的区间叫做sigmoid的饱和区间。这种情况下,当前层网络输出全是1或-1,下一层网络将无法在对全1或-1的特征数据进行计算,导致模型在训练中无法收敛。如果网络中有类似sigmoid这种带饱和区间的激活函数,则应该将BN层放在激活函数的前面,这样经过BN处理后的特征数据值域就变成了-1到1,再输入激活函数中,便可以正常实现非线性转换的功能。
而对于relu激活函数,应该将BN层放在relu层之后,这样不会对数据的正负比例造成影响,在保证正负比例的基础上再执行BN操作,可以使效果达到最优。
实验效果:
(1)BN放在relu前面: acc=0.474。
(2)BN放在relu前面,后面还有缩放和偏差:acc=0.478。
(3)BN放在relu后面:acc=0.499。
(4)BN放在relu后面,后面还有缩放和偏差:acc=0.493。
Ques3:应该将图片归一化到[0,1]区间还是[-1,1]区间?
在网络训练之前,必要步骤就是图片的预处理。有的模型会将图片归一化到[0,1]区间,而有的模型会将图片归一化到[-1,1]区间,实际中应该如何选择呢?
实验结果显示,对于归一化到[-1,1]区间的模型,它的收敛速度会稍慢一些,val loss更高。由此可见,将图片归一化到[0,1]区间的效果要好于[-1,1]区间,这是有原因的。
图片要归一化的值域与图片本身的值域特点有关。一般情况下,图片处理过程中的值域,一部分来自于图片本身,另一部分来自于填充值。我们在向模型输入图片数据时,先对图片尺寸做了同比例缩放,然后又向图片中填充了0。填充值应该遵循最少地改变原有数据分布的原则,要使填充值对特征运算的影响最小,一般会取原有数据的下限。正是由于填充值的影响,才使得将图片归一化到[0,1]区间要好于归一化到[-1,1]区间。如果用0进行填充,并且把图片归一化到[-1,1]区间,则会在原始图片中加入很多中间值,这会影响原始的分布。
总结起来,将图片归一化到[0,1]区间或[-1,1]区间并没有太大的区别,在选择时重点要与程序的其他部分结合起来。因为本例使用0作为填充值,所以将图片归一化到[0,1]区间;如果要将图片归一化到[-1,1]区间,应该将填充值设置为-1。
Ques4:训练时优化器的选取。
Amsgrad优化器综合性能优于Adam,实验效果如下:模型在经过400次迭代训练后,adam输出的val loss = 0.9148,远大于Amsgrad优化器的val loss = 0.7458。
SGD优化器对学习率比较挑剔。SGD是一个对学习率大小非常敏感的优化器,一旦设置的学习率不合适,训练出来的效果会很差。例如此处设置lr = 0.02训练400轮,val loss = 41.5003。
Ques5:STN 层的自适应思路。
在opencv库中有许多效果极佳的图像处理函数,但调用时往往需要我们手动输入一些超参数阈值,再基于这些参数对图片进行某些种变换。但对于不同扭曲的图片,仿射变换参数不是固定不变的,如果能使得这些参数能够自适应得到,而不用我们每次手动输入,则效果极佳。
这种自适应的思路可以通过融入神经网络的思想实现。神经网络模型有两个非常强大的能力:一是神经网络模型具有非常多的权重参数,从而具备非常高的自由度。借助全连接层的这些参数,我们能够加权得到我们想要的阈值,实现每张图片设置不同的超参数,这些超参数是基于图片特征加权计算得到的。二是有监督训练方式,可以通过大量标注,不断拟合网络权重,最终让每个网络参数都具备我们想赋予的现实意义。
STN 仿射变换的超参数是基于图片整体特征再网络加权得到的。对于一张图片,需要先利用卷积池化提取图片的各种特征,基于这个特征,学会这张图片特定的自适应仿射变换参数,然后用这些参数来调整原图,再重新进行图片分类。表面上是自适应学会阈值,实际上这种处理思路和人思考方式是一样的,先看清图片整体内容,再自行设置合适的阈值。
六、源码
main函数:
import numpy as np
import cv2
from get_data import make_data
from train import SequenceData
from train import train_network
from predict import predict_sequence
from visualize import visualize_stnif __name__ == "__main__":train_x, train_y, val_x, val_y, test_x, test_y = make_data()train_generator = SequenceData(train_x, train_y, 32)val_generator = SequenceData(val_x, val_y, 32)# train_network(train_generator, val_generator, epoch=400)# load_network_then_train(train_generator, val_generator, epoch=30,# input_name='/home/archer/8_XFD_CODE/OCR2/Logs/epoch008-loss0.123-val_loss4.745.h5',# output_name='second_weights.hdf5')predict_sequence(test_x, test_y)visualize_stn(test_x, test_y)
读取数据:
import numpy as np
import cv2
import osdef read_path():data_x = []data_y = []filename = os.listdir('img')filename.sort()for name in filename:img_path = 'img/' + namedata_x.append(img_path)obj1 = name.split('.')obj2 = obj1[0].split('_')obj3 = obj2[1]data_y.append(obj3)return data_x, data_ydef make_data():data_x, data_y = read_path()print('all image quantity : ', len(data_y)) # 10000train_x = data_x[:8000]train_y = data_y[:8000]val_x = data_x[8000:9000]val_y = data_y[8000:9000]test_x = data_x[9000:]test_y = data_y[9000:]return train_x, train_y, val_x, val_y, test_x, test_y
STN 网络结构:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import backend as kclass stn_transformer(tf.keras.layers.Layer):def __init__(self, output_size, **kwargs):self.output_size = output_sizesuper(stn_transformer, self).__init__(**kwargs)def compute_output_shape(self, input_shape):height, width = self.output_sizenum_channels = input_shape[0][-1]# input : [im, loc_net]# img : (1, 60, 60, 3),# loc_net : (1, 6)return None, height, width, num_channelsdef call(self, inputs, **kwargs):x, transformation = inputs# inputs : [shape=(1, 60, 60, 3), shape=(1, 6)]output = self.transform(x, transformation, self.output_size)return outputdef transform(self, x, affine_transformation, output_size):num_channels = x.shape[-1] # 3batch_size = k.shape(x)[0]transformations = tf.reshape(affine_transformation, shape=(batch_size, 2, 3)) # (2, 3)的坐标变换矩阵regular_grids = self.make_regular_grids(batch_size, *output_size) # (1, 3, width * height)sampled_grids = k.batch_dot(transformations, regular_grids)# (1, 2, 3) * (1, 3, width * height) = (1, 2, width * height)interpolated_image = self.interpolate(x, sampled_grids, output_size) # (width * height, 3)interpolated_image = tf.reshape(interpolated_image,tf.stack([batch_size, output_size[0], output_size[1], num_channels]))# (1, height, width, 3)return interpolated_imagedef make_regular_grids(self, batch_size, height, width):x_linspace = tf.linspace(-1.0, 1.0, width) # shape=(width,)y_linspace = tf.linspace(-1.0, 1.0, height) # shape=(height,)x_coordinates, y_coordinates = tf.meshgrid(x_linspace, y_linspace)x_coordinates = k.flatten(x_coordinates) # shape=(width * height,)y_coordinates = k.flatten(y_coordinates) # shape=(width * height,)ones = tf.ones_like(x_coordinates) # shape=(width * height,)grid = tf.concat([x_coordinates, y_coordinates, ones], 0) # shape=(3 * width * height,)grid = k.flatten(grid)grids = k.tile(grid, k.stack([batch_size]))regular_grids = tf.reshape(grids, (batch_size, 3, height * width))# (1, 3, width * height)# regular_grids 含义是 : 共有width * height个位置,每一列都代表该像素点的坐标位置 (x, y, 1)return regular_gridsdef interpolate(self, image, sampled_grids, output_size):# image.shape : (1, 60, 60, 3)# sampled_grids.shape : (1, 2, 60 * 60)# output_size : (60, 60)batch_size = k.shape(image)[0]height = k.shape(image)[1]width = k.shape(image)[2]num_channels = k.shape(image)[3]x = tf.cast(k.flatten(sampled_grids[:, 0:1, :]), dtype='float32') # (width * height,)y = tf.cast(k.flatten(sampled_grids[:, 1:2, :]), dtype='float32') # (width * height,)# 还原映射坐标对应于原始图片的值域,由[-1, 1]到[0, width]和[0, height]x = 0.5 * (x + 1.0) * tf.cast(width, dtype='float32') # (width * height,)y = 0.5 * (y + 1.0) * tf.cast(height, dtype='float32') # (width * height,)# 将转换后的坐标变为整数,同时计算出相邻坐标x0 = k.cast(x, 'int32')x1 = x0 + 1y0 = k.cast(y, 'int32')y1 = y0 + 1# 截断出界的坐标max_x = int(k.int_shape(image)[2] - 1)max_y = int(k.int_shape(image)[1] - 1)x0 = k.clip(x0, 0, max_x) # (width * height,)x1 = k.clip(x1, 0, max_x) # (width * height,)y0 = k.clip(y0, 0, max_y) # (width * height,)y1 = k.clip(y1, 0, max_y) # (width * height,)# 适配批次处理, 因为一次性要处理一个batch的图片,而在矩阵运算中又是拉成一个维度,所以需要记录好每张图片的起始索引位置pixels_batch = k.arange(0, batch_size) * (height * width)pixels_batch = k.expand_dims(pixels_batch, axis=-1)flat_output_size = output_size[0] * output_size[1]# 沿着轴重复张量的元素base = k.repeat_elements(pixels_batch, flat_output_size, axis=1)base = k.flatten(base) # 批次中每个图片的起始索引# 计算4个点在原始图片上的索引, 因为矩阵坐标拉直成向量时,是把每一行依次拼接的,所以都是乘以width# base_y0是代表height方向坐标为y0时应该累加多少偏移,base_y1是代表height方向坐标为y1时应该累加多少偏移。# 所以对四个坐标:(x0, y0), (x0, y1), (x1, y0), (x1, y1)# (x0, y0)和(x1, y0)的索引都是累加相同的base_y0;# (x0, y1)和(x1, y1)的索引都是累加相同的base_y1。base_y0 = base + (y0 * width)base_y1 = base + (y1 * width)indices_a = base_y0 + x0 # 代表(x0, y0)位置的索引 : (width * height,)indices_b = base_y1 + x0 # 代表(x0, y1)位置的索引 : (width * height,)indices_c = base_y0 + x1 # 代表(x1, y0)位置的索引 : (width * height,)indices_d = base_y1 + x1 # 代表(x1, y1)位置的索引 : (width * height,)flat_image = tf.reshape(image, shape=(-1, num_channels)) # (width * height, 3), 每个位置记录着r、g、b三个像素值flat_image = tf.cast(flat_image, dtype='float32') # (width * height, 3)pixel_value_a = tf.gather(flat_image, indices_a) # (width * height, 3), 代表每个(x0, y0)位置对应的r、g、b三个像素值pixel_value_b = tf.gather(flat_image, indices_b) # (width * height, 3), 代表每个(x0, y1)位置对应的r、g、b三个像素值pixel_value_c = tf.gather(flat_image, indices_c) # (width * height, 3), 代表每个(x1, y0)位置对应的r、g、b三个像素值pixel_value_d = tf.gather(flat_image, indices_d) # (width * height, 3), 代表每个(x1, y1)位置对应的r、g、b三个像素值x0 = tf.cast(x0, 'float32')x1 = tf.cast(x1, 'float32')y0 = tf.cast(y0, 'float32')y1 = tf.cast(y1, 'float32')# 在对映射坐标周围的4个像素点进行采样时,是按照距离远近定义权重的,距离越近的点权重越大。# 对于加权点(x0, y0), 利用(x, y)与(x1, y1)围成的面积来代表其权重,离得越近自然面积会越大,而且四个面积和加起来正好为1。area_a = tf.expand_dims(((x1 - x) * (y1 - y)), 1) # (x0, y0)位置的权重area_b = tf.expand_dims(((x1 - x) * (y - y0)), 1) # (x0, y1)位置的权重area_c = tf.expand_dims(((x - x0) * (y1 - y)), 1) # (x1, y0)位置的权重area_d = tf.expand_dims(((x - x0) * (y - y0)), 1) # (x1, y1)位置的权重values_a = area_a * pixel_value_avalues_b = area_b * pixel_value_bvalues_c = area_c * pixel_value_cvalues_d = area_d * pixel_value_dreturn values_a + values_b + values_c + values_d# interpolate 函数的逻辑原理:# 以图像正中心为坐标原点建立直角坐标系,每个像素点都可以分配到一个坐标。
# 在利用矩阵相乘做平移、旋转、放缩后,会新得到一组坐标,就是每个像素点原坐标仿射变换后变成的新坐标。
# 但是图像变换不是平面几何,你光变坐标位置没用,还得把每个坐标位置上的原始像素值也对应变换过去。
# 出现一个问题,就是仿射变换后的新坐标未必是整数,所以根据变换后上下左右四个位置像素的加权来确定。# 这段代码难就难在:
# 1、全部借用矩阵变换来处理,对线性代数的要求极高。
# 2、引入了batch批次,一次性要对一个批次的图片同时做处理,所以索引很烦。# ic层: BatchNormalization + Dropout
def ic(inputs, p):x = BatchNormalization(renorm=True)(inputs)return Dropout(p)(x)# 普通cnn提取feature map
def cnn(x):x = Conv2D(512, 5, strides=3, padding='same', activation='relu')(x)x = ic(x, 0.2)x = Conv2D(256, 3, strides=2, padding='same', activation='relu')(x)x = ic(x, 0.2)x = Conv2D(64, 3, strides=2, padding='same', activation='relu')(x)x = ic(x, 0.2)return xdef create_model(input_shape=(60, 60, 3), sampling_size=(60, 60), num_classes=26):image = Input(shape=input_shape)x = cnn(image)x = Conv2D(20, (3, 3), strides=(1, 1), padding='same', activation='relu')(x)loc_net = GlobalAveragePooling2D()(x)loc_net = Dense(6, kernel_initializer='zeros',bias_initializer=tf.keras.initializers.constant([[1.0, 0, 0], [0, 1.0, 0]]))(loc_net)x = stn_transformer(sampling_size, name='stn_transformer')([image, loc_net])x = cnn(x)x = Conv2D(num_classes, (3, 3), strides=(1, 1), padding='same', activation='relu')(x)x = GlobalAveragePooling2D()(x)x = Activation('softmax')(x)model = Model(inputs=image, outputs=x)model.summary()return model
训练:
import cv2
import numpy as np
import string
from tensorflow.keras.utils import *
import math
from stn_model import create_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpointchar_class = string.ascii_uppercase # A-Z
width, height, n_class = 60, 60, len(char_class)
char_list = list(char_class)def vector(i):v = np.zeros(n_class)v[i] = 1return vclass SequenceData(Sequence):def __init__(self, data_x, data_y, batch_size):self.batch_size = batch_sizeself.data_x = data_xself.data_y = data_yself.indexes = np.arange(len(self.data_x))def __len__(self):return math.floor(len(self.data_x) / float(self.batch_size))def on_epoch_end(self):np.random.shuffle(self.indexes)def __getitem__(self, idx):batch_index = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]batch_x = [self.data_x[k] for k in batch_index]batch_y = [self.data_y[k] for k in batch_index]x = np.zeros((self.batch_size, height, width, 3))y = np.zeros((self.batch_size, n_class))for i in range(self.batch_size):img = cv2.imread(batch_x[i])img1 = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)img2 = img1 / 255x[i, :, :, :] = img2char = batch_y[i]char_index = char_class.find(char)y[i, :] = vector(char_index)# print(char, char_index)# print(vector(char_index))# cv2.namedWindow("Image")# cv2.imshow("Image", img2)# cv2.waitKey(0)return x, ydef train_network(train_generator, validation_generator, epoch):model = create_model(num_classes=n_class)adam = Adam(lr=1e-4, amsgrad=True)log_dir = "Logs/"checkpoint = ModelCheckpoint(log_dir + 'epoch{epoch:03d}-train_loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss', save_weights_only=True, save_best_only=False, period=1)model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])model.fit_generator(train_generator,steps_per_epoch=len(train_generator),epochs=epoch,validation_data=validation_generator,validation_steps=len(validation_generator),callbacks=[checkpoint])model.save_weights('first_weights.hdf5')def load_network_then_train(train_generator, validation_generator, epoch, input_name, output_name):model = create_model()model.load_weights(input_name)print('网络层总数为:', len(model.layers))adam = Adam(lr=1e-4, amsgrad=True)log_dir = "Logs/"checkpoint = ModelCheckpoint(log_dir + 'epoch{epoch:03d}-train_loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss', save_weights_only=True, save_best_only=False, period=1)model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])model.fit_generator(train_generator,steps_per_epoch=len(train_generator),epochs=epoch,validation_data=validation_generator,validation_steps=len(validation_generator),callbacks=[checkpoint])model.save_weights(output_name)
预测:
import cv2
from stn_model import create_model
import numpy as np
import stringchar_class = string.ascii_uppercase
width, height, n_class = 60, 60, len(char_class)
char_list = list(char_class)def predict_sequence(test_x, test_y):predict_model = create_model(num_classes=n_class)predict_model.load_weights('best_val_loss0.008.h5')acc_count = 0 # 统计正确的序列个数for i in range(len(test_x)):img = cv2.imread(test_x[i])img1 = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)img2 = img1 / 255img3 = img2[np.newaxis, :, :, :]result = predict_model.predict(img3) # (1, 26)index = int(np.argmax(result[0]))char = char_list[index]if char == test_y[i]:acc_count = acc_count + 1else:print('预测字符:', char, '真实字符:', test_y[i])# cv2.namedWindow("img2")# cv2.imshow("img2", img2)# cv2.waitKey(0)print('sequence recognition accuracy : ', acc_count / len(test_x))# 经过test_x、test_y数据集测试,算法分类精度达到99.4%,比较满意
可视化STN网络层对字符图片的矫正效果:
import cv2
from stn_model import create_model
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import numpy as np
import stringchar_class = string.ascii_uppercase
width, height, n_class = 60, 60, len(char_class)
char_list = list(char_class)def visualize_stn(test_x, test_y):model = create_model()model.load_weights('best_val_loss0.008.h5')print(model.layers[13].output) # Tensor("stn_transformer/Identity:0", shape=(None, 60, 60, 3))new_model = Model(inputs=model.input, outputs=model.layers[13].output, name='new_model')for i in range(len(test_x)):img = cv2.imread(test_x[i])img1 = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)img2 = img1 / 255img3 = img2[np.newaxis, :, :, :]stn_img = new_model.predict(img3) # (1, 60, 60, 3)stn_img1 = stn_img[0]# 原始图片# cv2.namedWindow("img2")# cv2.imshow("img2", img2)# cv2.waitKey(0)# stn层矫正后的图片# cv2.namedWindow("stn_img")# cv2.imshow("stn_img", stn_img1)# cv2.waitKey(0)# demo_img中,左边是原始图片,右边是stn层矫正后的图片,中间用黄色区域分隔开来demo_img = np.zeros((height, 2 * width + 10, 3))demo_img[:, :width, :] = img2demo_img[:, width:(width + 10), :] = [0.0, 1.0, 1.0] # 中间间隔区域用黄色代表demo_img[:, (width + 10):, :] = stn_img1# cv2.namedWindow("demo_img")# cv2.imshow("demo_img", demo_img)# cv2.waitKey(0)cv2.imwrite('demo/' + str(i) + '.jpg', np.uint8(demo_img * 255))
七、相关链接
如果代码跑不通,或者想直接使用我自己制作的数据集,可以去下载项目链接:
https://blog.csdn.net/Twilight737
这篇关于STN_空间变换网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!