MINIST数据集测试不同参数对网络的影响

2024-03-18 11:08

本文主要是介绍MINIST数据集测试不同参数对网络的影响,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 一.介绍
    • 1.实验环境
    • 2.网络结构
  • 二.网络效果
    • 1.初始状态
    • 2.有BN层
    • 3.激活函数
      • tanh
      • sigmoid
      • relu
    • 4. 正则化
      • L2正则化
      • Dropout
    • 5.优化器
    • 6. 学习率衰减
  • 三.最优测试
  • 附: 完整代码

一.介绍

本实验使用两个不同的神经网络,通过MINIST数据集进行训练,查看不同情况下最后的效果。

1.实验环境

  1. Python 3.8
  2. Pytorch 1.8
  3. Pycharm

2.网络结构

单层卷积:一层卷积+一层池化+两层全连接

class Net_1(nn.Module):def __init__(self):super(Net_1, self).__init__()self.model = nn.Sequential(nn.Conv2d(1,10,kernel_size=3,stride=1),#nn.BatchNorm2d(10),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(10*13*13, 120),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(120, 10),)def forward(self,x):return self.model(x)

多层网络:三层卷积+三层池化+两层全连接

class Net_2(nn.Module):def __init__(self):super(Net_2, self).__init__()self.model = nn.Sequential(# out-> [40,28,28]nn.Conv2d(1, 40, kernel_size=3, stride=1,padding=1),#nn.BatchNorm2d(40),nn.ReLU(),# out->[40,14,14]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,12,12]nn.Conv2d(40, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,6,6]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,4,4]nn.Conv2d(20, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,2,2]nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(2 * 2 * 20, 100),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(100, 10),)def forward(self, x):return self.model(x)

二.网络效果

1.初始状态

Batchsize :200
Learning_rate:0.001
Epochs:15
优化器:torch.optim.SGD
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
以下变更均建立在此基础之上
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

2.有BN层

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0700.0690.9784
多层网络0.0400.0480.9839

添加BN层结果有所好转

3.激活函数

tanh

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2100.2020.9417
多层网络0.1350.1220.9675

单层多层的表现都不如relu,比较接近

sigmoid

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.7400.6930.8436
多层网络2.3012.3010.1135

relu

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

4. 正则化

L2正则化

optimizer=optim.SGD(model.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.001)
weight_decay设置为0.001
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1290.1200.9646
多层网络0.0960.0810.974

Dropout

nn.Dropout(p=0.5)
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2010.1280.9617
多层网络0.1620.0860.9718

5.优化器

将SGD优化器变更为Adam优化器
单层网络效果:
在这里插入图片描述
多层网络效果:

在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0070.0530.9836
多层网络0.0200.0340.9889

6. 学习率衰减

scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)此处使用Adam优化器,weight_decay为0.001,添加BN层,Dropout层,使用Relu函数
单层网络效果:

在这里插入图片描述
多层网络效果:
在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0580.0430.9853
多层网络0.0240.0220.9929

三.最优测试

Batchsize :200
Learning_rate:0.01
Epochs:15
优化器:torch.optim.Adam,weight_decay为0.001
学习率衰减:step_size=2, gamma=0.2
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
添加BN,dropout层
多层网络效果:
在这里插入图片描述

train_losstest_lossaccuracy
多层网络0.0310.0210.9925

附: 完整代码

import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optimdef main():batch_size = 200learning_rate = 0.01epochs = 50train_loader = DataLoader(datasets.MNIST('MNIST', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)test_loader = DataLoader(datasets.MNIST('MNIST', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)device = torch.device('cuda')model = Net_2().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.2)train_loss = []test_loss = []acc = []for epoch in range(epochs):loss_1 = 0.loss_2 = 0.model.train()for i, (x, label) in enumerate(train_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)optimizer.zero_grad()loss.backward()optimizer.step()loss_1 += loss.item()scheduler_step.step()loss_1 = loss_1 / (i + 1)train_loss.append(loss_1)print('train: 第{}次,loss为{}'.format(epoch, loss_1))model.eval()with torch.no_grad():correct = 0.for i, (x, label) in enumerate(test_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)pred = out.argmax(dim=1)correct += torch.eq(pred, label).sum().item()loss_2 += loss.item()accuracy = correct / len(test_loader.dataset)loss_2 = loss_2 / (i + 1)test_loss.append(loss_2)acc.append(accuracy)print('test: 第{}次,loss为{},accuracy为{}'.format(epoch, loss_2, accuracy))plt.figure(num=1, figsize=(10, 5.4))plt.subplot(121)# plt.title("loss")plt.plot(train_loss, 'b-', label='train_loss')plt.plot(test_loss, 'g-', label='test_loss')plt.xlabel('epoch')plt.ylabel('loss')plt.legend()plt.subplot(122)plt.plot(acc, 'g-', label='accuracy')plt.xlabel('epoch')plt.ylabel('accuracy')plt.legend()plt.show()if __name__ == '__main__':main()

这篇关于MINIST数据集测试不同参数对网络的影响的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822143

相关文章

MySQL大表数据的分区与分库分表的实现

《MySQL大表数据的分区与分库分表的实现》数据库的分区和分库分表是两种常用的技术方案,本文主要介绍了MySQL大表数据的分区与分库分表的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. mysql大表数据的分区1.1 什么是分区?1.2 分区的类型1.3 分区的优点1.4 分

Mysql删除几亿条数据表中的部分数据的方法实现

《Mysql删除几亿条数据表中的部分数据的方法实现》在MySQL中删除一个大表中的数据时,需要特别注意操作的性能和对系统的影响,本文主要介绍了Mysql删除几亿条数据表中的部分数据的方法实现,具有一定... 目录1、需求2、方案1. 使用 DELETE 语句分批删除2. 使用 INPLACE ALTER T

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Redis 中的热点键和数据倾斜示例详解

《Redis中的热点键和数据倾斜示例详解》热点键是指在Redis中被频繁访问的特定键,这些键由于其高访问频率,可能导致Redis服务器的性能问题,尤其是在高并发场景下,本文给大家介绍Redis中的热... 目录Redis 中的热点键和数据倾斜热点键(Hot Key)定义特点应对策略示例数据倾斜(Data S

MySQL中慢SQL优化的不同方式介绍

《MySQL中慢SQL优化的不同方式介绍》慢SQL的优化,主要从两个方面考虑,SQL语句本身的优化,以及数据库设计的优化,下面小编就来给大家介绍一下有哪些方式可以优化慢SQL吧... 目录避免不必要的列分页优化索引优化JOIN 的优化排序优化UNION 优化慢 SQL 的优化,主要从两个方面考虑,SQL 语

Python实现将MySQL中所有表的数据都导出为CSV文件并压缩

《Python实现将MySQL中所有表的数据都导出为CSV文件并压缩》这篇文章主要为大家详细介绍了如何使用Python将MySQL数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到... python将mysql数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到另一个

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

一文带你了解SpringBoot中启动参数的各种用法

《一文带你了解SpringBoot中启动参数的各种用法》在使用SpringBoot开发应用时,我们通常需要根据不同的环境或特定需求调整启动参数,那么,SpringBoot提供了哪些方式来配置这些启动参... 目录一、启动参数的常见传递方式二、通过命令行参数传递启动参数三、使用 application.pro

SpringBoot整合jasypt实现重要数据加密

《SpringBoot整合jasypt实现重要数据加密》Jasypt是一个专注于简化Java加密操作的开源工具,:本文主要介绍详细介绍了如何使用jasypt实现重要数据加密,感兴趣的小伙伴可... 目录jasypt简介 jasypt的优点SpringBoot使用jasypt创建mapper接口配置文件加密