利用tensorflow使用预训练神经网络(VGG16)来训练模型

2024-03-29 13:48

本文主要是介绍利用tensorflow使用预训练神经网络(VGG16)来训练模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用tensorflow使用预训练神经网络(VGG16)来训练模型

文章目录

  • 利用tensorflow使用预训练神经网络(VGG16)来训练模型
    • 1.预训练神经网络是什么
    • 2.数据集以及神经网络的选择
    • 3.VGG网络架构
    • 4.确定我们训练的步骤
    • 5.微调
    • 结语

1.预训练神经网络是什么

​ 预训练神经网络是提前在大型数据库上训练过的网络,他蕴含了在大型数据集上训练过的权重,我们可以将他迁移到小型数据集上从而得到较高的准确率,举个例子来说,原本的神经网络是对几百种分类的大型数据上进行学习的,我们得到的训练模型含有获得的权重,我们将他迁移到只有几种分类的小型数据上从而来完成分类识别任务,这种又称迁移学习(**所谓迁移学习,**或者领域适应Domain Adaptation,一般就是要将从源领域(Source Domain)学习到的东西应用到目标领域(Target Domain)上去。源领域和目标领域之间往往有gap/domain discrepancy(源领域的数据和目标领域的数据遵循不同的分布)。

迁移学习能够将适用于大数据的模型迁移到小数据上,实现个性化迁移。)。

​ 那么,这里我们有一个明显的问题,熟悉神经网络的同学肯定知道,我们训练模型其实就是训练神经网络的各项参数,让每个权重结合,最终能够完成得到正确的输出,那么举个例子,在识别椅子与桌子之上的数据集,上的网络,得到的权重为什么能够应用到其他地方上呢。其实我们换个地方想想,在分类问题上,到底几个分类是由我们最后的输出层(全连接层,分类器)决定的,而底层的卷积层只是负责提取特征,所以我们可以使用预训练神经网络的卷积来提取特征(这里是出自网络上搜索,以及询问他人得到的理解,有不对的可以评论区指正)。

2.数据集以及神经网络的选择

​ 迁移学习一般用于小型数据上的识别上,我们再把著名的猫狗大战数据集拿出来,该数据集由一堆猫和狗的图片来识别在这里插入图片描述

然后在预训练神经网络上的选择,keras其中有提供一堆预训练神经网络,比如有VGG16,VGG19,ResNet,Xception等许多预训练神经网络,我们在这里选择VGG16(其实我也有拿Xception来测试,但可能由于代码编写问题,拿到的正确率并不高,所以我并没有拿出来。。。)

3.VGG网络架构

​ VGG是一个十分经典的神经网络,网上资料很多,我找到了该网络的各项模型的具体层结构:在这里插入图片描述

我们其实可以自己来构造当前的层次,但是我们使用预训练神经网络主要是想拿他训练过的权重,

4.确定我们训练的步骤

​ 预训练神经网络步骤可以由下面这张图来确定在这里插入图片描述

我们获得预训练网络的卷积基,抛弃直接连接输出的分类器,冻结卷积基(就是不让卷积基础的权重随着后面我们新加入的分类器传播的权重来改变),然后根据具体的任务我们添加新的分类器从而来进行新的训练

所以我们直接编写代码如下:

import tensorflow as tf
from tensorflow import keras
from keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
#提取新的与训练神经网络我们不要输出层,只要中间提取特征的卷积层
conv_base=keras.applications.VGG16(weights='imagenet',include_top=False)#使用在imageNet上使用的权重
#include_top表示是否需要哪些全连接层
#查看一下获取的网络结构
conv_base.summary()

在这里插入图片描述

然后我们根据具体的分类器来调整我们最后的输出层,最后是二分类问题,所以我们构造后面的分类器如下:

model=keras.Sequential()
#因为我们使用了已经提前训练好的参数,我们并不希望该权重改变,所以我们要将该权重设置为不可训练
conv_base.trainable=False
model.add(conv_base)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(512,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
model.summary()

之后我们开始训练(一些其他数据读取与训练代码与博主上期博客一样),我们查看每次训练的结果:

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 1: loss: 0.329, accuracy: 0.855

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 2: loss: 0.250, accuracy: 0.893

可以看到仅仅只是,两次训练

准确率便能接近达到0.893,但是在接下来的训练上准确率是在90左右,无法上升,达到过拟合了,我们为了进一步提高准确率,可以使用微调方法。

5.微调

刚才我们说过底层的卷积基是负责提取特征的部分,含有的权重是可以迁移过来直接使用的,那么接近输出层的卷积层的权重,是不是意味着可以调整。使得我们增加我们的准确率,但是有一个前提是我们的分类器必须提前训练好的,不然随机初始化后的分类器可能会破坏我们调整的卷积基。所以我们接下来的微调都是在第一轮训练过的分类器

#使得一些高层的卷积基可以训练
conv_base.trainable=True
for layers in conv_base.layers[:-3]:layers.trainable=False
optimizer=keras.optimizers.Adam(0.00005)#同时需要调整更低的学习速率

然后我们开始我们第二波的训练

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 1: loss: 0.196, accuracy: 0.915

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------->

Epoch 2: loss: 0.136, accuracy: 0.944

可以看到,我们的正确率开始迅速上升达到95左右,说明我们的微调还是十分有用的。

结语

对于本次使用预训练神经网络中,我省略了一些代码,重点介绍了如何使用预训练神经网络以及微调上,有什么问题和独到的见解可以评论区指正,谢谢。

这篇关于利用tensorflow使用预训练神经网络(VGG16)来训练模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858759

相关文章

vue使用docxtemplater导出word

《vue使用docxtemplater导出word》docxtemplater是一种邮件合并工具,以编程方式使用并处理条件、循环,并且可以扩展以插入任何内容,下面我们来看看如何使用docxtempl... 目录docxtemplatervue使用docxtemplater导出word安装常用语法 封装导出方

Linux换行符的使用方法详解

《Linux换行符的使用方法详解》本文介绍了Linux中常用的换行符LF及其在文件中的表示,展示了如何使用sed命令替换换行符,并列举了与换行符处理相关的Linux命令,通过代码讲解的非常详细,需要的... 目录简介检测文件中的换行符使用 cat -A 查看换行符使用 od -c 检查字符换行符格式转换将

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

Elasticsearch 在 Java 中的使用教程

《Elasticsearch在Java中的使用教程》Elasticsearch是一个分布式搜索和分析引擎,基于ApacheLucene构建,能够实现实时数据的存储、搜索、和分析,它广泛应用于全文... 目录1. Elasticsearch 简介2. 环境准备2.1 安装 Elasticsearch2.2 J

使用C#代码在PDF文档中添加、删除和替换图片

《使用C#代码在PDF文档中添加、删除和替换图片》在当今数字化文档处理场景中,动态操作PDF文档中的图像已成为企业级应用开发的核心需求之一,本文将介绍如何在.NET平台使用C#代码在PDF文档中添加、... 目录引言用C#添加图片到PDF文档用C#删除PDF文档中的图片用C#替换PDF文档中的图片引言在当

Java中List的contains()方法的使用小结

《Java中List的contains()方法的使用小结》List的contains()方法用于检查列表中是否包含指定的元素,借助equals()方法进行判断,下面就来介绍Java中List的c... 目录详细展开1. 方法签名2. 工作原理3. 使用示例4. 注意事项总结结论:List 的 contain

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

Android中Dialog的使用详解

《Android中Dialog的使用详解》Dialog(对话框)是Android中常用的UI组件,用于临时显示重要信息或获取用户输入,本文给大家介绍Android中Dialog的使用,感兴趣的朋友一起... 目录android中Dialog的使用详解1. 基本Dialog类型1.1 AlertDialog(

Python使用自带的base64库进行base64编码和解码

《Python使用自带的base64库进行base64编码和解码》在Python中,处理数据的编码和解码是数据传输和存储中非常普遍的需求,其中,Base64是一种常用的编码方案,本文我将详细介绍如何使... 目录引言使用python的base64库进行编码和解码编码函数解码函数Base64编码的应用场景注意