b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器

2024-03-08 07:52

本文主要是介绍b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、损失函数
    • 1.简要介绍
    • 2.代码
  • 二、优化器
    • 1.简要介绍
    • 2.代码

一、损失函数

1.简要介绍

可参考博客:

常见的损失函数总结

损失函数的全面介绍

pytorch学习之十九种损失函数

损失函数(Loss Function)是用来衡量模型预测输出与实际标签之间的差异或误差程度的函数。在深度学习中,损失函数通常被设计为一个标量值,表示模型的预测值与真实标签之间的差异。

损失函数的选择对于训练深度学习模型非常重要,因为它直接影响着模型的训练效果和性能。在训练过程中,通过最小化损失函数来调整模型参数,使模型的预测结果逐渐接近真实标签,从而提高模型的准确性。

常见的损失函数:

均方误差(Mean Squared Error,MSE):用于回归任务,计算预测值与真实值之间的平方差的均值。

交叉熵损失函数(Cross Entropy Loss):用于分类任务,衡量模型输出的概率分布与真实标签的差异。

对数损失函数(Log Loss):也常用于二分类或多分类问题,衡量模型输出类别的概率与真实标签之间的关系。

Hinge损失函数:通常用于支持向量机(SVM)中,用于处理二分类问题。

Kullback-Leibler 散度(KL 散度):用于衡量两个概率分布之间的相似度。

2.代码

import torch
from torch import nn# 定义输入张量和目标张量
inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)# 对输入和目标张量进行reshape操作以匹配损失函数的输入要求
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))# 实例化 L1 损失函数
loss = nn.L1Loss()
# 计算 L1 损失值
result = loss(inputs, targets)
print(result)# 实例化均方误差(MSE)损失函数
loss_mse = nn.MSELoss()
# 计算均方误差损失值
result2 = loss_mse(inputs, targets)
print(result2)

代码运行结果:

在这里插入图片描述

二、优化器

1.简要介绍

优化器是深度学习中用于更新模型参数以最小化损失函数的算法。在神经网络训练过程中,通过计算损失函数对模型参数的梯度,优化器根据这些梯度来更新模型参数,使得损失函数逐渐减小,从而使模型更好地拟合训练数据。

2.代码

import torch.utils.data
import torchvision.datasets
from torch import nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# 加载 CIFAR-10 数据集
datasets = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
dataloader = DataLoader(datasets, batch_size=1)# 定义神经网络模型 Guodong
class Guodong(nn.Module):def __init__(self):super(Guodong, self).__init__()self.module1 = Sequential(Conv2d(3, 32, 5, padding=2),  # 输入通道数为3,输出通道数为32,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Conv2d(32, 32, 5, padding=2),  # 输入通道数为32,输出通道数为32,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Conv2d(32, 64, 5, padding=2),  # 输入通道数为32,输出通道数为64,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Flatten(),  # 将多维输入展平为一维Linear(1024, 64),  # 全连接层,输入维度为1024,输出维度为64Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10)def forward(self, input):output = self.module1(input)return output# 实例化 Guodong 模型
guodong = Guodong()# 定义交叉熵损失函数
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(guodong.parameters(), lr=0.01)
for epoch in range(20):loss_sum = 0.0# 遍历数据加载器中的数据for data in dataloader:imgs, target = data# 将图片输入模型得到预测输出outputs = guodong(imgs)# 计算交叉熵损失值result_loss = loss(outputs, target)optim.zero_grad()# 反向传播计算梯度result_loss.backward()optim.step()loss_sum += result_lossprint(loss_sum)


optim.zero_grad()
result_loss.backward()
optim.step()
这三处设置断点,调试,可以看到grad一开始是None,后来有了具体的数值

在这里插入图片描述
在这里插入图片描述
代码打印结果为:

在这里插入图片描述
(后面还没打印出来,程序运行有点慢QAQ)

可以看到最开始的时候loss_sum在变小,后来又变大。

在深度学习训练过程中,损失函数的值不一定是单调递减的,特别是在使用随机梯度下降(SGD)等基于随机采样的优化算法时。因此,损失函数值的变化可能会出现波动或不规则的情况。

sum_loss 的数值一开始是在减小的,但后来又增大了。这可能是由多种原因引起的,例如:

(1)训练数据的顺序:在每个 epoch 中,数据加载器可能以不同的顺序提供训练样本,这会导致模型参数的更新方向有所不同,从而影响损失函数的变化。

(2)学习率的设置:学习率控制着参数更新的步长大小,如果学习率设置得过大,可能会导致参数更新过程不稳定,损失函数值出现震荡或上升。

(3)模型复杂度和数据集的匹配程度:如果模型的复杂度过高,而训练数据集较小或难以拟合,模型可能会出现过拟合现象,导致损失函数值增大。

这篇关于b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/786489

相关文章

将sqlserver数据迁移到mysql的详细步骤记录

《将sqlserver数据迁移到mysql的详细步骤记录》:本文主要介绍将SQLServer数据迁移到MySQL的步骤,包括导出数据、转换数据格式和导入数据,通过示例和工具说明,帮助大家顺利完成... 目录前言一、导出SQL Server 数据二、转换数据格式为mysql兼容格式三、导入数据到MySQL数据

关于rpc长连接与短连接的思考记录

《关于rpc长连接与短连接的思考记录》文章总结了RPC项目中长连接和短连接的处理方式,包括RPC和HTTP的长连接与短连接的区别、TCP的保活机制、客户端与服务器的连接模式及其利弊分析,文章强调了在实... 目录rpc项目中的长连接与短连接的思考什么是rpc项目中的长连接和短连接与tcp和http的长连接短

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

C#使用HttpClient进行Post请求出现超时问题的解决及优化

《C#使用HttpClient进行Post请求出现超时问题的解决及优化》最近我的控制台程序发现有时候总是出现请求超时等问题,通常好几分钟最多只有3-4个请求,在使用apipost发现并发10个5分钟也... 目录优化结论单例HttpClient连接池耗尽和并发并发异步最终优化后优化结论我直接上优化结论吧,

Java内存泄漏问题的排查、优化与最佳实践

《Java内存泄漏问题的排查、优化与最佳实践》在Java开发中,内存泄漏是一个常见且令人头疼的问题,内存泄漏指的是程序在运行过程中,已经不再使用的对象没有被及时释放,从而导致内存占用不断增加,最终... 目录引言1. 什么是内存泄漏?常见的内存泄漏情况2. 如何排查 Java 中的内存泄漏?2.1 使用 J

Python MySQL如何通过Binlog获取变更记录恢复数据

《PythonMySQL如何通过Binlog获取变更记录恢复数据》本文介绍了如何使用Python和pymysqlreplication库通过MySQL的二进制日志(Binlog)获取数据库的变更记录... 目录python mysql通过Binlog获取变更记录恢复数据1.安装pymysqlreplicat

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

MySQL不使用子查询的原因及优化案例

《MySQL不使用子查询的原因及优化案例》对于mysql,不推荐使用子查询,效率太差,执行子查询时,MYSQL需要创建临时表,查询完毕后再删除这些临时表,所以,子查询的速度会受到一定的影响,本文给大家... 目录不推荐使用子查询和JOIN的原因解决方案优化案例案例1:查询所有有库存的商品信息案例2:使用EX

MySQL中my.ini文件的基础配置和优化配置方式

《MySQL中my.ini文件的基础配置和优化配置方式》文章讨论了数据库异步同步的优化思路,包括三个主要方面:幂等性、时序和延迟,作者还分享了MySQL配置文件的优化经验,并鼓励读者提供支持... 目录mysql my.ini文件的配置和优化配置优化思路MySQL配置文件优化总结MySQL my.ini文件

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日