MINIST数据集测试不同参数对网络的影响

2024-03-18 11:08

本文主要是介绍MINIST数据集测试不同参数对网络的影响,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 一.介绍
    • 1.实验环境
    • 2.网络结构
  • 二.网络效果
    • 1.初始状态
    • 2.有BN层
    • 3.激活函数
      • tanh
      • sigmoid
      • relu
    • 4. 正则化
      • L2正则化
      • Dropout
    • 5.优化器
    • 6. 学习率衰减
  • 三.最优测试
  • 附: 完整代码

一.介绍

本实验使用两个不同的神经网络,通过MINIST数据集进行训练,查看不同情况下最后的效果。

1.实验环境

  1. Python 3.8
  2. Pytorch 1.8
  3. Pycharm

2.网络结构

单层卷积:一层卷积+一层池化+两层全连接

class Net_1(nn.Module):def __init__(self):super(Net_1, self).__init__()self.model = nn.Sequential(nn.Conv2d(1,10,kernel_size=3,stride=1),#nn.BatchNorm2d(10),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(10*13*13, 120),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(120, 10),)def forward(self,x):return self.model(x)

多层网络:三层卷积+三层池化+两层全连接

class Net_2(nn.Module):def __init__(self):super(Net_2, self).__init__()self.model = nn.Sequential(# out-> [40,28,28]nn.Conv2d(1, 40, kernel_size=3, stride=1,padding=1),#nn.BatchNorm2d(40),nn.ReLU(),# out->[40,14,14]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,12,12]nn.Conv2d(40, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,6,6]nn.MaxPool2d(kernel_size=2, stride=2),# out->[20,4,4]nn.Conv2d(20, 20, kernel_size=3, stride=1),#nn.BatchNorm2d(20),nn.ReLU(),# out->[20,2,2]nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(2 * 2 * 20, 100),nn.ReLU(),#nn.Dropout(p=0.5),nn.Linear(100, 10),)def forward(self, x):return self.model(x)

二.网络效果

1.初始状态

Batchsize :200
Learning_rate:0.001
Epochs:15
优化器:torch.optim.SGD
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
以下变更均建立在此基础之上
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

2.有BN层

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0700.0690.9784
多层网络0.0400.0480.9839

添加BN层结果有所好转

3.激活函数

tanh

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2100.2020.9417
多层网络0.1350.1220.9675

单层多层的表现都不如relu,比较接近

sigmoid

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.7400.6930.8436
多层网络2.3012.3010.1135

relu

单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1360.1280.9634
多层网络0.0900.0800.9734

4. 正则化

L2正则化

optimizer=optim.SGD(model.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.001)
weight_decay设置为0.001
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.1290.1200.9646
多层网络0.0960.0810.974

Dropout

nn.Dropout(p=0.5)
单层网络效果:
在这里插入图片描述

多层网络效果:
在这里插入图片描述

最后一次的结果:

train_losstest_lossaccuracy
单层网络0.2010.1280.9617
多层网络0.1620.0860.9718

5.优化器

将SGD优化器变更为Adam优化器
单层网络效果:
在这里插入图片描述
多层网络效果:

在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0070.0530.9836
多层网络0.0200.0340.9889

6. 学习率衰减

scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)此处使用Adam优化器,weight_decay为0.001,添加BN层,Dropout层,使用Relu函数
单层网络效果:

在这里插入图片描述
多层网络效果:
在这里插入图片描述
最后一次的结果:

train_losstest_lossaccuracy
单层网络0.0580.0430.9853
多层网络0.0240.0220.9929

三.最优测试

Batchsize :200
Learning_rate:0.01
Epochs:15
优化器:torch.optim.Adam,weight_decay为0.001
学习率衰减:step_size=2, gamma=0.2
损失函数:nn.CrossEntropyLoss()
激活函数:Relu
添加BN,dropout层
多层网络效果:
在这里插入图片描述

train_losstest_lossaccuracy
多层网络0.0310.0210.9925

附: 完整代码

import matplotlib.pyplot as plt
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optimdef main():batch_size = 200learning_rate = 0.01epochs = 50train_loader = DataLoader(datasets.MNIST('MNIST', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)test_loader = DataLoader(datasets.MNIST('MNIST', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])),batch_size=batch_size,shuffle=True)device = torch.device('cuda')model = Net_2().to(device)criteon = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)scheduler_step = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.2)train_loss = []test_loss = []acc = []for epoch in range(epochs):loss_1 = 0.loss_2 = 0.model.train()for i, (x, label) in enumerate(train_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)optimizer.zero_grad()loss.backward()optimizer.step()loss_1 += loss.item()scheduler_step.step()loss_1 = loss_1 / (i + 1)train_loss.append(loss_1)print('train: 第{}次,loss为{}'.format(epoch, loss_1))model.eval()with torch.no_grad():correct = 0.for i, (x, label) in enumerate(test_loader):x, label = x.to(device), label.to(device)out = model(x)loss = criteon(out, label)pred = out.argmax(dim=1)correct += torch.eq(pred, label).sum().item()loss_2 += loss.item()accuracy = correct / len(test_loader.dataset)loss_2 = loss_2 / (i + 1)test_loss.append(loss_2)acc.append(accuracy)print('test: 第{}次,loss为{},accuracy为{}'.format(epoch, loss_2, accuracy))plt.figure(num=1, figsize=(10, 5.4))plt.subplot(121)# plt.title("loss")plt.plot(train_loss, 'b-', label='train_loss')plt.plot(test_loss, 'g-', label='test_loss')plt.xlabel('epoch')plt.ylabel('loss')plt.legend()plt.subplot(122)plt.plot(acc, 'g-', label='accuracy')plt.xlabel('epoch')plt.ylabel('accuracy')plt.legend()plt.show()if __name__ == '__main__':main()

这篇关于MINIST数据集测试不同参数对网络的影响的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/822143

相关文章

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

Python如何计算两个不同类型列表的相似度

《Python如何计算两个不同类型列表的相似度》在编程中,经常需要比较两个列表的相似度,尤其是当这两个列表包含不同类型的元素时,下面小编就来讲讲如何使用Python计算两个不同类型列表的相似度吧... 目录摘要引言数字类型相似度欧几里得距离曼哈顿距离字符串类型相似度Levenshtein距离Jaccard相

在不同系统间迁移Python程序的方法与教程

《在不同系统间迁移Python程序的方法与教程》本文介绍了几种将Windows上编写的Python程序迁移到Linux服务器上的方法,包括使用虚拟环境和依赖冻结、容器化技术(如Docker)、使用An... 目录使用虚拟环境和依赖冻结1. 创建虚拟环境2. 冻结依赖使用容器化技术(如 docker)1. 创

关于Spring @Bean 相同加载顺序不同结果不同的问题记录

《关于Spring@Bean相同加载顺序不同结果不同的问题记录》本文主要探讨了在Spring5.1.3.RELEASE版本下,当有两个全注解类定义相同类型的Bean时,由于加载顺序不同,最终生成的... 目录问题说明测试输出1测试输出2@Bean注解的BeanDefiChina编程nition加入时机总结问题说明

Redis的数据过期策略和数据淘汰策略

《Redis的数据过期策略和数据淘汰策略》本文主要介绍了Redis的数据过期策略和数据淘汰策略,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录一、数据过期策略1、惰性删除2、定期删除二、数据淘汰策略1、数据淘汰策略概念2、8种数据淘汰策略

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

Java通过反射获取方法参数名的方式小结

《Java通过反射获取方法参数名的方式小结》这篇文章主要为大家详细介绍了Java如何通过反射获取方法参数名的方式,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、前言2、解决方式方式2.1: 添加编译参数配置 -parameters方式2.2: 使用Spring的内部工具类 -

Python给Excel写入数据的四种方法小结

《Python给Excel写入数据的四种方法小结》本文主要介绍了Python给Excel写入数据的四种方法小结,包含openpyxl库、xlsxwriter库、pandas库和win32com库,具有... 目录1. 使用 openpyxl 库2. 使用 xlsxwriter 库3. 使用 pandas 库

SpringBoot定制JSON响应数据的实现

《SpringBoot定制JSON响应数据的实现》本文主要介绍了SpringBoot定制JSON响应数据的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们... 目录前言一、如何使用@jsonView这个注解?二、应用场景三、实战案例注解方式编程方式总结 前言