Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)

本文主要是介绍Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        本笔记记录CNN做CIFAR100数据集的训练相关内容,代码中使用了类似VGG13的网络结构,做了两个Sequetial(CNN和全连接层),没有用Flatten层而是用reshape操作做CNN和全连接层的中转操作。由于网络层次较深,参数量相比之前的网络多了不少,因此只做了10次epoch(RTX4090),没有继续跑了,最终准确率大概在33.8%左右。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Inputos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,yy_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))#创建CNN网络,总共4个unit,每个unit主要是两个卷积层和Max Pooling池化层
cnn_layers = [#unit 1layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 2layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 3layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 4layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 5layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),
]def main():#[b, 32, 32, 3] => [b, 1, 1, 512]cnn_net = Sequential(cnn_layers)cnn_net.build(input_shape=[None, 32, 32, 3])#测试一下卷积层的输出#x = tf.random.normal([4, 32, 32, 3])#out = cnn_net(x)#print(out.shape)#创建全连接层, 输出为100分类fc_net = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(100, activation=None),])fc_net.build(input_shape=[None, 512])#设置优化器optimizer = optimizers.Adam(learning_rate=1e-4)#记录cnn层和全连接层所有可训练参数, 实现的效果类似list拼接,比如# [1, 2] + [3, 4] => [1, 2, 3, 4]variables = cnn_net.trainable_variables + fc_net.trainable_variables#进行训练num_epoches = 10for epoch in range(num_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:#[b, 32, 32, 3] => [b, 1, 1, 512]out = cnn_net(x)#flatten打平 => [b, 512]out = tf.reshape(out, [-1, 512])#使用全连接层做100分类logits输出#[b, 512] => [b, 100]logits = fc_net(out)#标签做one_hot encodingy_onehot = tf.one_hot(y, depth=100)#计算损失loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)#计算梯度grads = tape.gradient(loss, variables)#更新参数optimizer.apply_gradients(zip(grads, variables))if (step % 100 == 0):print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))#进行验证total_samples = 0total_correct = 0for x,y in test_db:out = cnn_net(x)out = tf.reshape(out, [-1, 512])logits = fc_net(out)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_samples += x.shape[0]total_correct += int(correct)#统计准确率acc = total_correct / total_samplesprint("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
if __name__ == '__main__':main()

运行结果:

这篇关于Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/918731

相关文章

使用Python实现一键隐藏屏幕并锁定输入

《使用Python实现一键隐藏屏幕并锁定输入》本文主要介绍了使用Python编写一个一键隐藏屏幕并锁定输入的黑科技程序,能够在指定热键触发后立即遮挡屏幕,并禁止一切键盘鼠标输入,这样就再也不用担心自己... 目录1. 概述2. 功能亮点3.代码实现4.使用方法5. 展示效果6. 代码优化与拓展7. 总结1.

使用Python开发一个简单的本地图片服务器

《使用Python开发一个简单的本地图片服务器》本文介绍了如何结合wxPython构建的图形用户界面GUI和Python内建的Web服务器功能,在本地网络中搭建一个私人的,即开即用的网页相册,文中的示... 目录项目目标核心技术栈代码深度解析完整代码工作流程主要功能与优势潜在改进与思考运行结果总结你是否曾经

Linux中的计划任务(crontab)使用方式

《Linux中的计划任务(crontab)使用方式》:本文主要介绍Linux中的计划任务(crontab)使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、前言1、linux的起源与发展2、什么是计划任务(crontab)二、crontab基础1、cro

kotlin中const 和val的区别及使用场景分析

《kotlin中const和val的区别及使用场景分析》在Kotlin中,const和val都是用来声明常量的,但它们的使用场景和功能有所不同,下面给大家介绍kotlin中const和val的区别,... 目录kotlin中const 和val的区别1. val:2. const:二 代码示例1 Java

C++变换迭代器使用方法小结

《C++变换迭代器使用方法小结》本文主要介绍了C++变换迭代器使用方法小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录1、源码2、代码解析代码解析:transform_iterator1. transform_iterat

C++中std::distance使用方法示例

《C++中std::distance使用方法示例》std::distance是C++标准库中的一个函数,用于计算两个迭代器之间的距离,本文主要介绍了C++中std::distance使用方法示例,具... 目录语法使用方式解释示例输出:其他说明:总结std::distance&n编程bsp;是 C++ 标准

Python获取中国节假日数据记录入JSON文件

《Python获取中国节假日数据记录入JSON文件》项目系统内置的日历应用为了提升用户体验,特别设置了在调休日期显示“休”的UI图标功能,那么问题是这些调休数据从哪里来呢?我尝试一种更为智能的方法:P... 目录节假日数据获取存入jsON文件节假日数据读取封装完整代码项目系统内置的日历应用为了提升用户体验,

vue使用docxtemplater导出word

《vue使用docxtemplater导出word》docxtemplater是一种邮件合并工具,以编程方式使用并处理条件、循环,并且可以扩展以插入任何内容,下面我们来看看如何使用docxtempl... 目录docxtemplatervue使用docxtemplater导出word安装常用语法 封装导出方

Linux换行符的使用方法详解

《Linux换行符的使用方法详解》本文介绍了Linux中常用的换行符LF及其在文件中的表示,展示了如何使用sed命令替换换行符,并列举了与换行符处理相关的Linux命令,通过代码讲解的非常详细,需要的... 目录简介检测文件中的换行符使用 cat -A 查看换行符使用 od -c 检查字符换行符格式转换将

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.