cifar10在resnet50上精度达95%以上

2024-08-23 15:52

本文主要是介绍cifar10在resnet50上精度达95%以上,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、tensorflow环境

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timeimport tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical(x_train, y_train), (x_test, y_test) =tf.keras.datasets.cifar10.load_data()
y_train = to_categorical(y_train,num_classes=10)
y_test = to_categorical(y_test,num_classes=10)def preprocess(image,lable):resize_image = tf.image.resize(image,[224,224])resize_image = tf.image.random_flip_left_right(resize_image)resize_image = tf.cast(resize_image, tf.float32)/255.mean = [0.5, 0.45, 0.4]std = [0.23, 0.22, 0.21]resize_image = (resize_image - mean)/stdreturn resize_image,labledef preprocess_dataset(x, y, batch_size, seed=42):dataset = tf.data.Dataset.from_tensor_slices((x, y))dataset = dataset.shuffle(buffer_size=batch_size,seed=seed)dataset = dataset.map(preprocess).batch(batch_size)dataset = dataset.prefetch(1)return datasettrain_set = preprocess_dataset(x_train,y_train,16)
test_set = preprocess_dataset(x_test, y_test,4)base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = Truemodel = keras.models.Sequential([base_model,keras.layers.GlobalAveragePooling2D(),keras.layers.Dense(10, activation='softmax')
])optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
loss_function = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss_function, metrics=['accuracy'])history = model.fit(train_set, epochs=5)model.evaluate(test_set)

结果如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2、pytorch环境

import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch import optimmytrans = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize((0.5, 0.45, 0.4), (0.23, 0.22, 0.21))])train_set = torchvision.datasets.CIFAR10(root='./cifar-10-python/', train=True, download=True, transform=mytrans)
train_set = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)test_set = torchvision.datasets.CIFAR10(root='./cifar-10-python/', train=False, download=True, transform=mytrans)
test_set = torch.utils.data.DataLoader(test_set, batch_size=16, shuffle=True, num_workers=0)model = torchvision.models.resnet50(weights=True)
nums = model.fc.in_features
model.fc = nn.Linear(nums, 10)device = torch.device("cuda:0")
# device = torch.device("cpu")
model = model.to(device)myoptim = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
myloss = nn.CrossEntropyLoss()n_epoch = 10
for epoch in range(1,n_epoch+1):print("epoch {}/{}".format(epoch,n_epoch))for step,test_data in enumerate(train_set):image, label = test_data[0].to(device), test_data[1].to(device)predict_label = model.forward(image)     loss = myloss(predict_label, label)myoptim.zero_grad()loss.backward()myoptim.step()end = "" if step != len(train_set)-1 else "\n"print("\rtrain iteration {}/{} - training loss: {:.3f}".format(step+1, len(train_set), loss.item()), end=end)model.eval()total = 0correct = 0for step,test_data in enumerate(test_set):image = test_data[0].to(device)label = test_data[1].to(device)outputs = model(image)_, pred = torch.max(outputs.data,1)total += label.size(0)correct += (pred==label).sum().item()end = "" if step != len(test_set)-1 else "\n"print('\rtest iteration {}/{} - testing accuracy: {:.3f}%'.format(step+1, len(test_set), 100*correct/total), end=end)

结果如下:

在这里插入图片描述

这篇关于cifar10在resnet50上精度达95%以上的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1099825

相关文章

Android6.0以上权限申请

说明: 部分1:出自:http://jijiaxin89.com/2015/08/30/Android-s-Runtime-Permission/ android M 的名字官方刚发布不久,最终正式版即将来临! android在不断发展,最近的更新 M 非常不同,一些主要的变化例如运行时权限将有颠覆性影响。惊讶的是android社区鲜有谈论这事儿,尽管这事很重要或许在不远的将来会引

Ubuntu 16.04安装python3.6及其以上版本

Ubuntu16.04 自带python2.7 和3.5,若需要安装高版本需要添加源 网络搜索几个源 sudo add-apt-repository ppa:jonathonf/python-3.6sudo apt-get updatesudo apt-get install python3.6 这个已不存在 会报错如下 Cannot add PPA: 'ppa:~jonathonf

在Webmin上默认状态无法正常显示 Mariadb V11.02及以上版本

OS: Armbian OS 24.5.0 Bookworm Mariadb V11.02及以上版本 Webmin:V2.202 小众问题,主要是记录一下。 如题 Webmin 默认无法 Mariadb V11.02及以上版本 如果对 /etc/webmin/mysql/config 文件作相应调整就可以再现Mariadb管理界面。 路径+文件:/etc/webmin/mysql/config

Mongodb最新版本安装(4.0以上)

最近学习Mongodb数据库 总结了一下心得分享给大家 一,首先需要去官网下载Mongodb  网址https://www.mongodb.com/download-center/community   如下图所示 选择版本号、对应的操作系统、安装包后 点击download  开始下载  下载完成后双击安装就行 安装步骤 1.双击之后如图所示 直接next 2.由于笔者没有保存这张

《长得太长也是错?——后端 Long 型 ID 精度丢失的“奇妙”修复之旅》

引言 在前后端分离的时代,我们的生活充满了无数的机遇与挑战——包括那些突然冒出来的让人抓狂的 Bug。今天我们要聊的,就是一个让无数开发者哭笑不得的经典问题:后端 Long 类型 ID 过长导致前端精度丢失。说到这个问题,那可真是“万恶之源”啊,谁让 JavaScript 只能安全地处理 Number.MAX_SAFE_INTEGER(也就是 9007199254740991)以内的数值呢?

关于精度的问题

在一些问题中经常会遇到一些关于精度的保留; 1.要求保留小数后n位小数:代码如下 #include<stdio.h>int main(){double num = 1.123456789;int n = 6;printf("%0.*lf\n",n,num);    //1.123457return 0;} 注意,他会在小数点第n+1位四舍五入; 2.要求截取小数后n位,也就是不四舍五

蓝牙技术|超高精度蓝牙位置服务将成为蓝牙定位产品发展方向

随着市场需求的变化,精确的距离测量成为提升安全性和用户体验的重要因素。预计未来五年蓝牙位置服务设备的年均增长率为22%,到2028年出货量将达到5.63亿台。 为了满足这一需求,SIG即将在2024年下半年推出一项新功能——蓝牙信道探测(Bluetooth Channel Sounding)。这项新技术基于相位测量(PBR)和往返时间(RTT)两种测距方式,为蓝牙设备带来安全且精确的测距功能。

PointNet++改进策略 :模块改进 | PAConv,位置自适应卷积提升精度

题目:PAConv: Position Adaptive Convolution with Dynamic Kernel Assembling on Point Clouds来源:CVPR2021机构:香港大学论文:https://arxiv.org/abs/2103.14635代码:https://github.com/CVMI-Lab/PAConv 前言 PAConv,全称为位置自适应卷积

xcode6以上空模板配置

这边提供一份空模板配置的文档: http://pan.baidu.com/s/1dDxxg9j 1.确定安装的Xcode在应用程序中得绝对路径。 2.打开终端,使用cd指令,进入目录 AddMissingTemplates-master(要找到你存放AddMissingTemplates-master的路径),然后运行里面的脚本AddMissingTemplates.sh就ok了。(运行方

【matlab】双精度每字符占8字节,单精度每字符占4字节

>> help magicmagic - Magic squareThis MATLAB function returns an n-by-n matrix constructed from the integers 1through n^2 with equal row and column sums.M = magic(n)magic 的参考页另请参阅 ones, rand>> a = ma