一起深度学习——CIFAR10

2024-05-09 22:52
文章标签 学习 深度 一起 cifar10

本文主要是介绍一起深度学习——CIFAR10,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CIFAR10

  • 目的:
  • 实现步骤:
    • 1、导包:
    • 2、下载数据集
    • 3、整理数据集
    • 4、将验证集从原始的训练集中拆分出来
    • 5、数据增强
    • 6、加载数据集
    • 7、定义训练模型:
    • 8、定义训练函数:
    • 9、定义参数,开始训练:

目的:

实现从数据集中进行分类,一共有10个类别。

实现步骤:

1、导包:

import collectionsimport torch
from torch import nn
from d2l import torch as d2l
import shutil
import os
import math
import pandas as pd
import torchvision

2、下载数据集

#下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = Trueif demo:data_dir = d2l.download_extract('cifar10_tiny')
else:data_dir = '../data/kaggle/cifar-10/'

3、整理数据集

ef read_csv_labels (fname):with open(fname,'r') as f:lines = f.readlines()[1:] #[1:0]表示从第二行开始读取,因为第一行是行头# 按照逗号分割每一行, 且rstrip 去除每行末尾的换行符# eg: ["apple,orange,banana\n"] => [['apple','orange','banana']]tokens = [l.rstrip().split(',') for l in lines]return dict((name,label) for name, label in tokens)labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))

4、将验证集从原始的训练集中拆分出来

def copyfile(filename,target_dir):os.makedirs(target_dir,exist_ok=True)shutil.copy(filename,target_dir)# print("训练样本:",len(labels))
# print("类别:",len(set(labels.values())))
def reorg_train_valid(data_dir,labels,valid_ratio):# Counter :用于计算每个类别出现的次数# most_common() :用于统计返回出现次数最多的元素(类别,次数),是一个列表,并且按照次数降序的方式存储# 【-1】表示取出列表中的最后一个元组。# 【1】 表示取出该元组的次数。n = collections.Counter(labels.values()).most_common()[-1][1]# math.floor(): 向下取整  math.ceil(): 向上取整# 每个类别分配给验证集的最小个数n_valid_per_label = max(1,math.floor((n * valid_ratio)))label_count = {}for photo in os.listdir(os.path.join(data_dir,'train')):label = labels[photo.split('.')[0]] #取出该标签所对应的类别# print(train_file)fname = os.path.join(data_dir,'train',photo)copyfile(fname,os.path.join(data_dir,'train_valid_test','train_valid',label))#如果该类别没有在label_count中或者是 数量小于规定的最小值,则将其复制到验证集中if label not in label_count or label_count[label] < n_valid_per_label:copyfile(fname,os.path.join(data_dir,'train_valid_test','valid',label))label_count[label] = label_count.get(label,0) +1else:copyfile(fname,os.path.join(data_dir,'train_valid_test','train',label))return n_valid_per_labeldef reorg_test(data_dir):for test_file in os.listdir(os.path.join(data_dir,'test')):copyfile(os.path.join(data_dir,'test',test_file), os.path.join(data_dir,'train_valid_test','test','unknown'))def reorg_cifar10_data(data_dir,valid_ratio):labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))reorg_train_valid(data_dir,labels,valid_ratio)reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir,valid_ratio)
"""
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
"""

5、数据增强

transform_train = torchvision.transforms.Compose([# 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强torchvision.transforms.Resize(40),torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),# 标准化图像的每个通道 : 消除评估结果中的随机性torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])

6、加载数据集

#加载数据集
train_ds,train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir,'train_valid_test',folder),transform=transform_train)for folder in ['train','train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test) for folder in ['valid', 'test']
]
#定义迭代器
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]valid_iter = torch.utils.data.DataLoader(valid_ds,batch_size,shuffle=False,drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds,batch_size,shuffle=False,drop_last=False
)

7、定义训练模型:

def get_net():num_classes = 10 #输出标签net = d2l.resnet18(num_classes,in_channels=3)return net

损失函数:

#损失函数
loss = nn.CrossEntropyLoss(reduction='none')

8、定义训练函数:

#定义训练函数
def train(net,train_iter,valid_iter,num_epochs,lr,wd,devices,lr_period,lr_decay):trainer = torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=wd)#学习率调度器:在经过lr_period个epoch之后,将学习率乘以lr_decay.scheduler = torch.optim.lr_scheduler.StepLR(trainer,lr_period,lr_decay)num_batches,timer = len(train_iter),d2l.Timer()legend = ['train_loss','train_acc']if valid_iter is not None:legend.append('valid_acc')animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=legend)net = nn.DataParallel(net,device_ids=devices).to(devices[0])for epoch in range(num_epochs):# 设置为训练模式net.train()metric = d2l.Accumulator(3)for i,(X,y) in enumerate(train_iter):timer.start()l,acc = d2l.train_batch_ch13(net,X,y,loss,trainer,devices)metric.add(l,acc,y.shape[0])timer.stop()if (i + 1) % (num_batches // 5 ) ==0 or i == num_batches - 1:animator.add(epoch + (i + 1)/num_batches,(metric[0]/metric[2],metric[1]/metric[2],None))if valid_iter is not None:valid_acc = d2l.evaluate_accuracy_gpu(net,valid_iter)animator.add(epoch+1,(None,None,valid_acc))scheduler.step()  #更新学习率measures = (f'train loss {metric[0] / metric[2]:.3f},'f'train acc{metric[1] / metric[2]:.3f}')if valid_iter is not None:measures += f', valid acc {valid_acc:.3f}'print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'f'example/sec on {str(devices)}')

9、定义参数,开始训练:


import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 100, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

这篇关于一起深度学习——CIFAR10的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/974738

相关文章

SpringCloud动态配置注解@RefreshScope与@Component的深度解析

《SpringCloud动态配置注解@RefreshScope与@Component的深度解析》在现代微服务架构中,动态配置管理是一个关键需求,本文将为大家介绍SpringCloud中相关的注解@Re... 目录引言1. @RefreshScope 的作用与原理1.1 什么是 @RefreshScope1.

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Redis中高并发读写性能的深度解析与优化

《Redis中高并发读写性能的深度解析与优化》Redis作为一款高性能的内存数据库,广泛应用于缓存、消息队列、实时统计等场景,本文将深入探讨Redis的读写并发能力,感兴趣的小伙伴可以了解下... 目录引言一、Redis 并发能力概述1.1 Redis 的读写性能1.2 影响 Redis 并发能力的因素二、

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

Redis 内存淘汰策略深度解析(最新推荐)

《Redis内存淘汰策略深度解析(最新推荐)》本文详细探讨了Redis的内存淘汰策略、实现原理、适用场景及最佳实践,介绍了八种内存淘汰策略,包括noeviction、LRU、LFU、TTL、Rand... 目录一、 内存淘汰策略概述二、内存淘汰策略详解2.1 ​noeviction(不淘汰)​2.2 ​LR

Python与DeepSeek的深度融合实战

《Python与DeepSeek的深度融合实战》Python作为最受欢迎的编程语言之一,以其简洁易读的语法、丰富的库和广泛的应用场景,成为了无数开发者的首选,而DeepSeek,作为人工智能领域的新星... 目录一、python与DeepSeek的结合优势二、模型训练1. 数据准备2. 模型架构与参数设置3

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

Go中sync.Once源码的深度讲解

《Go中sync.Once源码的深度讲解》sync.Once是Go语言标准库中的一个同步原语,用于确保某个操作只执行一次,本文将从源码出发为大家详细介绍一下sync.Once的具体使用,x希望对大家有... 目录概念简单示例源码解读总结概念sync.Once是Go语言标准库中的一个同步原语,用于确保某个操