MINIST数据集测试不同参数对网络的影响

2024-03-18 11:08

本文主要是介绍MINIST数据集测试不同参数对网络的影响,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 一.介绍
    • 1.实验环境
    • 2.网络结构
  • 二.网络效果
    • 1.初始状态
    • 2.有BN层
    • 3.激活函数
      • tanh
      • sigmoid
      • relu
    • 4. 正则化
      • L2正则化
      • Dropout
    • 5.优化器
    • 6. 学习率衰减
  • 三.最优测试
  • 附: 完整代码

一.介绍

本实验使用两个不同的神经网络,通过MINIST数据集进行训练,查看不同情况下最后的效果。

1.实验环境

  1. Python 3.8
  2. Pytorch 1.8
  3. Pycharm

2.网络结构

单层卷积:一层卷积+一层池化+两层全连接

class Net_1(nn.Module):def __init__(self):super(Net_1, self).__init__()self.model = nn.Sequential(nn.Conv2d(1,10,kernel_size=3,stride=1),#nn.BatchNorm2d(10),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(10*13*13, 120),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(120, 10),)def forward(self,x):return self.model(x)

多层网络:三层卷积+三层池化+两层全连接

class Net_2(nn.Module):def __init__(self):super(Net_2, self).__init__()self.model = nn.Sequential(# out-> [40,28,28]nn.Conv2d(1, 40, kernel_size=3, stride=1,padding=1),#nn.BatchNorm2d(40),nn.ReLU(),# out->[40,14,14]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,12,12]nn.Conv2d(40, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,6,6]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,4,4]nn.Conv2d(20, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,2,2]nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(2 * 2 * 20, 100),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(100, 10),)def forward(self, x):return self.model(x)

二.网络效果

1.初始状态

Batchsize :200
Learning_rate:0.001
Epochs:15
优化器:torch.optim.SGD
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
以下变更均建立在此基础之上
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

2.有BN层

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0700.0690.9784
多层网络0.0400.0480.9839

添加BN层结果有所好转

3.激活函数

tanh

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2100.2020.9417
多层网络0.1350.1220.9675

单层多层的表现都不如relu,比较接近

sigmoid

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.7400.6930.8436
多层网络2.3012.3010.1135

relu

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

4. 正则化

L2正则化

optimizer=optim.SGD(model.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.001)
weight_decay设置为0.001
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1290.1200.9646
多层网络0.0960.0810.974

Dropout

nn.Dropout(p=0.5)
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2010.1280.9617
多层网络0.1620.0860.9718

5.优化器

将SGD优化器变更为Adam优化器
单层网络效果:
在这里插入图片描述
多层网络效果:

在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0070.0530.9836
多层网络0.0200.0340.9889

6. 学习率衰减

scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)此处使用Adam优化器,weight_decay为0.001,添加BN层,Dropout层,使用Relu函数
单层网络效果:

在这里插入图片描述
多层网络效果:
在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0580.0430.9853
多层网络0.0240.0220.9929

三.最优测试

Batchsize :200
Learning_rate:0.01
Epochs:15
优化器:torch.optim.Adam,weight_decay为0.001
学习率衰减:step_size=2, gamma=0.2
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
添加BN,dropout层
多层网络效果:
在这里插入图片描述

train_losstest_lossaccuracy
多层网络0.0310.0210.9925

附: 完整代码

import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optimdef main():batch_size = 200learning_rate = 0.01epochs = 50train_loader = DataLoader(datasets.MNIST('MNIST', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)test_loader = DataLoader(datasets.MNIST('MNIST', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)device = torch.device('cuda')model = Net_2().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.2)train_loss = []test_loss = []acc = []for epoch in range(epochs):loss_1 = 0.loss_2 = 0.model.train()for i, (x, label) in enumerate(train_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)optimizer.zero_grad()loss.backward()optimizer.step()loss_1 += loss.item()scheduler_step.step()loss_1 = loss_1 / (i + 1)train_loss.append(loss_1)print('train: 第{}次,loss为{}'.format(epoch, loss_1))model.eval()with torch.no_grad():correct = 0.for i, (x, label) in enumerate(test_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)pred = out.argmax(dim=1)correct += torch.eq(pred, label).sum().item()loss_2 += loss.item()accuracy = correct / len(test_loader.dataset)loss_2 = loss_2 / (i + 1)test_loss.append(loss_2)acc.append(accuracy)print('test: 第{}次,loss为{},accuracy为{}'.format(epoch, loss_2, accuracy))plt.figure(num=1, figsize=(10, 5.4))plt.subplot(121)# plt.title("loss")plt.plot(train_loss, 'b-', label='train_loss')plt.plot(test_loss, 'g-', label='test_loss')plt.xlabel('epoch')plt.ylabel('loss')plt.legend()plt.subplot(122)plt.plot(acc, 'g-', label='accuracy')plt.xlabel('epoch')plt.ylabel('accuracy')plt.legend()plt.show()if __name__ == '__main__':main()

这篇关于MINIST数据集测试不同参数对网络的影响的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822143

相关文章

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

Python MySQL如何通过Binlog获取变更记录恢复数据

《PythonMySQL如何通过Binlog获取变更记录恢复数据》本文介绍了如何使用Python和pymysqlreplication库通过MySQL的二进制日志(Binlog)获取数据库的变更记录... 目录python mysql通过Binlog获取变更记录恢复数据1.安装pymysqlreplicat

Linux使用dd命令来复制和转换数据的操作方法

《Linux使用dd命令来复制和转换数据的操作方法》Linux中的dd命令是一个功能强大的数据复制和转换实用程序,它以较低级别运行,通常用于创建可启动的USB驱动器、克隆磁盘和生成随机数据等任务,本文... 目录简介功能和能力语法常用选项示例用法基础用法创建可启动www.chinasem.cn的 USB 驱动

java脚本使用不同版本jdk的说明介绍

《java脚本使用不同版本jdk的说明介绍》本文介绍了在Java中执行JavaScript脚本的几种方式,包括使用ScriptEngine、Nashorn和GraalVM,ScriptEngine适用... 目录Java脚本使用不同版本jdk的说明1.使用ScriptEngine执行javascript2.

Python如何使用seleniumwire接管Chrome查看控制台中参数

《Python如何使用seleniumwire接管Chrome查看控制台中参数》文章介绍了如何使用Python的seleniumwire库来接管Chrome浏览器,并通过控制台查看接口参数,本文给大家... 1、cmd打开控制台,启动谷歌并制定端口号,找不到文件的加环境变量chrome.exe --rem

Oracle数据库使用 listagg去重删除重复数据的方法汇总

《Oracle数据库使用listagg去重删除重复数据的方法汇总》文章介绍了在Oracle数据库中使用LISTAGG和XMLAGG函数进行字符串聚合并去重的方法,包括去重聚合、使用XML解析和CLO... 目录案例表第一种:使用wm_concat() + distinct去重聚合第二种:使用listagg,

Python实现将实体类列表数据导出到Excel文件

《Python实现将实体类列表数据导出到Excel文件》在数据处理和报告生成中,将实体类的列表数据导出到Excel文件是一项常见任务,Python提供了多种库来实现这一目标,下面就来跟随小编一起学习一... 目录一、环境准备二、定义实体类三、创建实体类列表四、将实体类列表转换为DataFrame五、导出Da

Python实现数据清洗的18种方法

《Python实现数据清洗的18种方法》本文主要介绍了Python实现数据清洗的18种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录1. 去除字符串两边空格2. 转换数据类型3. 大小写转换4. 移除列表中的重复元素5. 快速统

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档

在Pandas中进行数据重命名的方法示例

《在Pandas中进行数据重命名的方法示例》Pandas作为Python中最流行的数据处理库,提供了强大的数据操作功能,其中数据重命名是常见且基础的操作之一,本文将通过简洁明了的讲解和丰富的代码示例,... 目录一、引言二、Pandas rename方法简介三、列名重命名3.1 使用字典进行列名重命名3.编