Tensorflow2.0学习(2):基于fashion_mnist数据集的分类基本步骤

2024-01-13 09:08

本文主要是介绍Tensorflow2.0学习(2):基于fashion_mnist数据集的分类基本步骤,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 导入包、打印包的信息

  其中 %matplotlib inline 是IPython中的魔法函数,作用是:在利用matplotlib.pyplot作图或创建画布时不需要plt.show(),即可实现图像的显示。

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(sys.version_info)
for module in mpl, np ,pd, sklearn, tf, keras:print(module.__name__, module.__version__)
2.1.0
sys.version_info(major=3, minor=7, micro=4, releaselevel='final', serial=0)
matplotlib 3.1.1
numpy 1.16.5
pandas 0.25.1
sklearn 0.21.3
tensorflow 2.1.0
tensorflow_core.python.keras.api._v2.keras 2.2.4-tf
  • 下载、读取、分割数据集
# 读取keras中的进阶版mnist数据集
fashion_mnist = keras.datasets.fashion_mnist
# 加载数据集,切分为训练集和测试集
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
# 从训练集中将后五千张作为验证集,前五千张作为训练集
# [:5000]默认从头开始,从头开始取5000个
# [5000:]从第5001开始,结束位置默认为最后
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
# 打印这些数据集的大小
print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 7s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step
(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)
(10000, 28, 28) (10000,)

  可以得出,训练集有55000张图片,每张图片为28*28;验证集有5000张图片;测试集有1000张图片。

  • 显示图片
def show_single_image(img_arr):plt.imshow(img_arr, cmap="binary")plt.show()
# 显示训练集第一张图片    
show_single_image(x_train[0])

在这里插入图片描述

# 设置n_rows行与n_cols列用来显示图像,共显示x_data个图像
#(y_data是其标签,class_names是其真实的类名)
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):# 断言:不满足条件触发异常assert len(x_data) == len(y_data)assert n_rows * n_cols <len(x_data)plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))for row in range(n_rows):for col in range(n_cols):index = n_cols * row + col# 总的图像为n_rows * n_cols,当前图像位置为index+1plt.subplot(n_rows, n_cols, index+1)plt.imshow(x_data[index], cmap="binary",interpolation = 'nearest')plt.axis('off')plt.title(class_names[y_data[index]])plt.show()
class_names = ['T-shirt','Trouser','Pullover','Dress','Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag','Ankle boot']
# 显示训练集的十五张图片              
show_imgs(3, 5, x_train, y_train, class_names)

在这里插入图片描述

  • 构建模型
# tf.keras.models.Sequential() 构建模型的容器# 创建一个Sequential的对象,顺序模型,多个网络层的线性堆叠
# 可使用add方法将各层添加到模块中
model = keras.models.Sequential()# 添加层次
# 输入层:Flatten将28*28的图像矩阵展平成为一个一维向量
model.add(keras.layers.Flatten(input_shape=[28,28]))# 全连接层(上层所有单元与下层所有单元都连接):
# 第一层300个单元,第二层100个单元,激活函数为 relu:
# relu: y = max(0, x)
model.add(keras.layers.Dense(300,activation="relu"))          
model.add(keras.layers.Dense(100,activation="relu"))# 输出为长度为10的向量,激活函数为 softmax:
# softmax: 将向量变成概率分布,x = [x1, x2, x3],
# y = [e^x1/sum, e^x2/sum, e^x3/sum],sum = e^x1+e^x2+e^x3
model.add(keras.layers.Dense(10,activation="softmax"))# 目标函数的构建与求解方法
# 为什么使用sparse? : 
# y->是一个数,要用sparse_categorical_crossentropy
# y->是一个向量,直接用categorical_crossentropy
model.compile(loss="sparse_categorical_crossentropy",optimizer="adam",metrics = ["accuracy"])
"""
构建模型也可以这样:
model = keras.models.Sequential([keras.layers.Flatten(input_shape=[28,28]),keras.layers.Dense(300,activation="relu"),keras.layers.Dense(300,activation="relu"),keras.layers.Dense(10,activation="softmax")
])"""
  • 查看模型
# 看模型的层情况
model.layers
[<tensorflow.python.keras.layers.core.Flatten at 0x28a9c583088>,<tensorflow.python.keras.layers.core.Dense at 0x28a9c583108>,<tensorflow.python.keras.layers.core.Dense at 0x28a9c5cbec8>,<tensorflow.python.keras.layers.core.Dense at 0x28a9c925f88>]
# 看模型的概况
model.summary()# 参数量:[None,784] * w + b -> [None, 300]:w.shape=[784, 300],b = 300
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 300)               235500    
_________________________________________________________________
dense_1 (Dense)              (None, 100)               30100     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1010      
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
  • 训练
# 开启训练
# epochs:训练集遍历10次
# validation_data:每个epoch就会用验证集验证
# 会发现loss和accuracy到后面一直不变,因为用sgd梯度下降法会导致陷入局部最小值点
# 因此将loss函数的下降方法改为 adam
history = model.fit(x_train, y_train, epochs=10,validation_data=(x_valid, y_valid))
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 4s 69us/sample - loss: 2.5882 - accuracy: 0.7635 - val_loss: 0.6161 - val_accuracy: 0.8122
Epoch 2/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.5384 - accuracy: 0.8169 - val_loss: 0.5430 - val_accuracy: 0.8268
Epoch 3/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.4813 - accuracy: 0.8317 - val_loss: 0.5800 - val_accuracy: 0.8166
Epoch 4/10
55000/55000 [==============================] - 3s 63us/sample - loss: 0.4575 - accuracy: 0.8381 - val_loss: 0.5061 - val_accuracy: 0.8276
Epoch 5/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.4363 - accuracy: 0.8447 - val_loss: 0.4295 - val_accuracy: 0.8612
Epoch 6/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.4133 - accuracy: 0.8530 - val_loss: 0.4093 - val_accuracy: 0.8602
Epoch 7/10
55000/55000 [==============================] - 3s 63us/sample - loss: 0.3913 - accuracy: 0.8603 - val_loss: 0.4328 - val_accuracy: 0.8506
Epoch 8/10
55000/55000 [==============================] - 3s 63us/sample - loss: 0.3798 - accuracy: 0.8635 - val_loss: 0.3907 - val_accuracy: 0.8564
Epoch 9/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.3694 - accuracy: 0.8681 - val_loss: 0.4227 - val_accuracy: 0.8564
Epoch 10/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.3625 - accuracy: 0.8696 - val_loss: 0.4136 - val_accuracy: 0.8630
  • 查看训练后的结果
type(history)
tensorflow.python.keras.callbacks.History
history.history
# loss是训练集的损失值,val_loss是测试集的损失值
{'loss': [2.5882461955070495,0.5384013248010115,0.48129710446704516,0.45748968857851896,0.43628633408329703,0.41328840144330803,0.3912581247546456,0.37976206094351683,0.3693601242488081,0.3625139371091669],'accuracy': [0.7634909,0.81685454,0.8316727,0.83805454,0.84467274,0.8529818,0.8602909,0.8634727,0.86814547,0.8695818],'val_loss': [0.6161495039701462,0.542988519859314,0.5799951359272003,0.506083318400383,0.4295428094863892,0.40931380726099015,0.4327934848666191,0.39065408419966696,0.42266337755918504,0.41358628759980204],'val_accuracy': [0.8122,0.8268,0.8166,0.8276,0.8612,0.8602,0.8506,0.8564,0.8564,0.863]}
def plot_learning_curves(history):# 将history.history转换为dataframe格式pd.DataFrame(history.history).plot(figsize=(8, 5 ))plt.grid(True)# gca:get current axes,gcf: get current figureplt.gca().set_ylim(0, 1)plt.show()
plot_learning_curves(history)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PzEQNq8y-1582528839439)(output_10_0.png)]

# 转换为dataframe格式进行查看
pd.DataFrame(history.history)
lossaccuracyval_lossval_accuracy
02.5882460.7634910.6161500.8122
10.5384010.8168550.5429890.8268
20.4812970.8316730.5799950.8166
30.4574900.8380550.5060830.8276
40.4362860.8446730.4295430.8612
50.4132880.8529820.4093140.8602
60.3912580.8602910.4327930.8506
70.3797620.8634730.3906540.8564
80.3693600.8681450.4226630.8564
90.3625140.8695820.4135860.8630

这篇关于Tensorflow2.0学习(2):基于fashion_mnist数据集的分类基本步骤的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/600942

相关文章

IDEA与JDK、Maven安装配置完整步骤解析

《IDEA与JDK、Maven安装配置完整步骤解析》:本文主要介绍如何安装和配置IDE(IntelliJIDEA),包括IDE的安装步骤、JDK的下载与配置、Maven的安装与配置,以及如何在I... 目录1. IDE安装步骤2.配置操作步骤3. JDK配置下载JDK配置JDK环境变量4. Maven配置下

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

MySQL InnoDB引擎ibdata文件损坏/删除后使用frm和ibd文件恢复数据

《MySQLInnoDB引擎ibdata文件损坏/删除后使用frm和ibd文件恢复数据》mysql的ibdata文件被误删、被恶意修改,没有从库和备份数据的情况下的数据恢复,不能保证数据库所有表数据... 参考:mysql Innodb表空间卸载、迁移、装载的使用方法注意!此方法只适用于innodb_fi

mysql通过frm和ibd文件恢复表_mysql5.7根据.frm和.ibd文件恢复表结构和数据

《mysql通过frm和ibd文件恢复表_mysql5.7根据.frm和.ibd文件恢复表结构和数据》文章主要介绍了如何从.frm和.ibd文件恢复MySQLInnoDB表结构和数据,需要的朋友可以参... 目录一、恢复表结构二、恢复表数据补充方法一、恢复表结构(从 .frm 文件)方法 1:使用 mysq

mysql8.0无备份通过idb文件恢复数据的方法、idb文件修复和tablespace id不一致处理

《mysql8.0无备份通过idb文件恢复数据的方法、idb文件修复和tablespaceid不一致处理》文章描述了公司服务器断电后数据库故障的过程,作者通过查看错误日志、重新初始化数据目录、恢复备... 周末突然接到一位一年多没联系的妹妹打来电话,“刘哥,快来救救我”,我脑海瞬间冒出妙瓦底,电信火苲马扁.

golang获取prometheus数据(prometheus/client_golang包)

《golang获取prometheus数据(prometheus/client_golang包)》本文主要介绍了使用Go语言的prometheus/client_golang包来获取Prometheu... 目录1. 创建链接1.1 语法1.2 完整示例2. 简单查询2.1 语法2.2 完整示例3. 范围值

springboot rocketmq配置生产者和消息者的步骤

《springbootrocketmq配置生产者和消息者的步骤》本文介绍了如何在SpringBoot中集成RocketMQ,包括添加依赖、配置application.yml、创建生产者和消费者,并展... 目录1. 添加依赖2. 配置application.yml3. 创建生产者4. 创建消费者5. 使用在

javaScript在表单提交时获取表单数据的示例代码

《javaScript在表单提交时获取表单数据的示例代码》本文介绍了五种在JavaScript中获取表单数据的方法:使用FormData对象、手动提取表单数据、使用querySelector获取单个字... 方法 1:使用 FormData 对象FormData 是一个方便的内置对象,用于获取表单中的键值

mac安装nvm(node.js)多版本管理实践步骤

《mac安装nvm(node.js)多版本管理实践步骤》:本文主要介绍mac安装nvm(node.js)多版本管理的相关资料,NVM是一个用于管理多个Node.js版本的命令行工具,它允许开发者在... 目录NVM功能简介MAC安装实践一、下载nvm二、安装nvm三、安装node.js总结NVM功能简介N

Python中多线程和多进程的基本用法详解

《Python中多线程和多进程的基本用法详解》这篇文章介绍了Python中多线程和多进程的相关知识,包括并发编程的优势,多线程和多进程的概念、适用场景、示例代码,线程池和进程池的使用,以及如何选择合适... 目录引言一、并发编程的主要优势二、python的多线程(Threading)1. 什么是多线程?2.