Pytorch 入门之-- 逻辑回归 Logistic_Regression

2024-05-06 22:32

本文主要是介绍Pytorch 入门之-- 逻辑回归 Logistic_Regression,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

代码+注释

__author__ = 'SherlockLiao'import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import time
# 定义超参数
batch_size = 32
learning_rate = 1e-3
num_epoches = 100# 下载训练集 MNIST 手写数字训练集,不需要下载的话download=FALSE
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=TRUE)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 定义 Logistic Regression 模型
# pytorch的基本框架init forward class Logstic_Regression(nn.Module):def __init__(self, in_dim, n_class):super(Logstic_Regression, self).__init__()self.logstic = nn.Linear(in_dim, n_class)def forward(self, x):out = self.logstic(x)return out#模型实例化
model = Logstic_Regression(28 * 28, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:model = model.cuda()# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(num_epoches):print('*' * 10)print('epoch {}'.format(epoch + 1))since = time.time()running_loss = 0.0running_acc = 0.0for i, data in enumerate(train_loader, 1):#迭代数据img, label = dataimg = img.view(img.size(0), -1)  # 将图片展开成 28x28if use_gpu:img = Variable(img).cuda()label = Variable(label).cuda()else:img = Variable(img)label = Variable(label)# 向前传播out = model(img)loss = criterion(out, label)#解决方法:把        train_loss += loss.data[0]        修改为        train_loss += loss.item()running_loss += loss.item()* label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()running_acc += num_correct.item()# 向后传播optimizer.zero_grad()loss.backward()optimizer.step()if i % 300 == 0:print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, num_epoches, running_loss / (batch_size * i),running_acc / (batch_size * i)))
print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(train_dataset))))model.eval()
eval_loss = 0.
eval_acc = 0.
for data in test_loader:img, label = dataimg = img.view(img.size(0), -1)if use_gpu:img = Variable(img, volatile=True).cuda()label = Variable(label, volatile=True).cuda()else:img = Variable(img, volatile=True)label = Variable(label, volatile=True)out = model(img)loss = criterion(out, label)eval_loss += loss.item() * label.size(0)_, pred = torch.max(out, 1)num_correct = (pred == label).sum()eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc / (len(test_dataset))))
print('Time:{:.1f} s'.format(time.time() - since))
print()# 保存模型
torch.save(model.state_dict(), './logstic.pth')

 

这篇关于Pytorch 入门之-- 逻辑回归 Logistic_Regression的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/965558

相关文章

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

Java逻辑运算符之&&、|| 与&、 |的区别及应用

《Java逻辑运算符之&&、||与&、|的区别及应用》:本文主要介绍Java逻辑运算符之&&、||与&、|的区别及应用的相关资料,分别是&&、||与&、|,并探讨了它们在不同应用场景中... 目录前言一、基本概念与运算符介绍二、短路与与非短路与:&& 与 & 的区别1. &&:短路与(AND)2. &:非短

Python FastAPI入门安装使用

《PythonFastAPI入门安装使用》FastAPI是一个现代、快速的PythonWeb框架,用于构建API,它基于Python3.6+的类型提示特性,使得代码更加简洁且易于绶护,这篇文章主要介... 目录第一节:FastAPI入门一、FastAPI框架介绍什么是ASGI服务(WSGI)二、FastAP