b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器

2024-03-08 07:52

本文主要是介绍b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、损失函数
    • 1.简要介绍
    • 2.代码
  • 二、优化器
    • 1.简要介绍
    • 2.代码

一、损失函数

1.简要介绍

可参考博客:

常见的损失函数总结

损失函数的全面介绍

pytorch学习之十九种损失函数

损失函数(Loss Function)是用来衡量模型预测输出与实际标签之间的差异或误差程度的函数。在深度学习中,损失函数通常被设计为一个标量值,表示模型的预测值与真实标签之间的差异。

损失函数的选择对于训练深度学习模型非常重要,因为它直接影响着模型的训练效果和性能。在训练过程中,通过最小化损失函数来调整模型参数,使模型的预测结果逐渐接近真实标签,从而提高模型的准确性。

常见的损失函数:

均方误差(Mean Squared Error,MSE):用于回归任务,计算预测值与真实值之间的平方差的均值。

交叉熵损失函数(Cross Entropy Loss):用于分类任务,衡量模型输出的概率分布与真实标签的差异。

对数损失函数(Log Loss):也常用于二分类或多分类问题,衡量模型输出类别的概率与真实标签之间的关系。

Hinge损失函数:通常用于支持向量机(SVM)中,用于处理二分类问题。

Kullback-Leibler 散度(KL 散度):用于衡量两个概率分布之间的相似度。

2.代码

import torch
from torch import nn# 定义输入张量和目标张量
inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)# 对输入和目标张量进行reshape操作以匹配损失函数的输入要求
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))# 实例化 L1 损失函数
loss = nn.L1Loss()
# 计算 L1 损失值
result = loss(inputs, targets)
print(result)# 实例化均方误差(MSE)损失函数
loss_mse = nn.MSELoss()
# 计算均方误差损失值
result2 = loss_mse(inputs, targets)
print(result2)

代码运行结果:

在这里插入图片描述

二、优化器

1.简要介绍

优化器是深度学习中用于更新模型参数以最小化损失函数的算法。在神经网络训练过程中,通过计算损失函数对模型参数的梯度,优化器根据这些梯度来更新模型参数,使得损失函数逐渐减小,从而使模型更好地拟合训练数据。

2.代码

import torch.utils.data
import torchvision.datasets
from torch import nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader# 加载 CIFAR-10 数据集
datasets = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 创建数据加载器
dataloader = DataLoader(datasets, batch_size=1)# 定义神经网络模型 Guodong
class Guodong(nn.Module):def __init__(self):super(Guodong, self).__init__()self.module1 = Sequential(Conv2d(3, 32, 5, padding=2),  # 输入通道数为3,输出通道数为32,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Conv2d(32, 32, 5, padding=2),  # 输入通道数为32,输出通道数为32,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Conv2d(32, 64, 5, padding=2),  # 输入通道数为32,输出通道数为64,卷积核大小为5,填充为2MaxPool2d(2),  # 最大池化层,核大小为2Flatten(),  # 将多维输入展平为一维Linear(1024, 64),  # 全连接层,输入维度为1024,输出维度为64Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10)def forward(self, input):output = self.module1(input)return output# 实例化 Guodong 模型
guodong = Guodong()# 定义交叉熵损失函数
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(guodong.parameters(), lr=0.01)
for epoch in range(20):loss_sum = 0.0# 遍历数据加载器中的数据for data in dataloader:imgs, target = data# 将图片输入模型得到预测输出outputs = guodong(imgs)# 计算交叉熵损失值result_loss = loss(outputs, target)optim.zero_grad()# 反向传播计算梯度result_loss.backward()optim.step()loss_sum += result_lossprint(loss_sum)


optim.zero_grad()
result_loss.backward()
optim.step()
这三处设置断点,调试,可以看到grad一开始是None,后来有了具体的数值

在这里插入图片描述
在这里插入图片描述
代码打印结果为:

在这里插入图片描述
(后面还没打印出来,程序运行有点慢QAQ)

可以看到最开始的时候loss_sum在变小,后来又变大。

在深度学习训练过程中,损失函数的值不一定是单调递减的,特别是在使用随机梯度下降(SGD)等基于随机采样的优化算法时。因此,损失函数值的变化可能会出现波动或不规则的情况。

sum_loss 的数值一开始是在减小的,但后来又增大了。这可能是由多种原因引起的,例如:

(1)训练数据的顺序:在每个 epoch 中,数据加载器可能以不同的顺序提供训练样本,这会导致模型参数的更新方向有所不同,从而影响损失函数的变化。

(2)学习率的设置:学习率控制着参数更新的步长大小,如果学习率设置得过大,可能会导致参数更新过程不稳定,损失函数值出现震荡或上升。

(3)模型复杂度和数据集的匹配程度:如果模型的复杂度过高,而训练数据集较小或难以拟合,模型可能会出现过拟合现象,导致损失函数值增大。

这篇关于b站小土堆pytorch学习记录—— P23-P24 损失函数、反向传播和优化器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/786489

相关文章

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

GO语言中函数命名返回值的使用

《GO语言中函数命名返回值的使用》在Go语言中,函数可以为其返回值指定名称,这被称为命名返回值或命名返回参数,这种特性可以使代码更清晰,特别是在返回多个值时,感兴趣的可以了解一下... 目录基本语法函数命名返回特点代码示例命名特点基本语法func functionName(parameters) (nam

Python实战之SEO优化自动化工具开发指南

《Python实战之SEO优化自动化工具开发指南》在数字化营销时代,搜索引擎优化(SEO)已成为网站获取流量的重要手段,本文将带您使用Python开发一套完整的SEO自动化工具,需要的可以了解下... 目录前言项目概述技术栈选择核心模块实现1. 关键词研究模块2. 网站技术seo检测模块3. 内容优化分析模

Python Counter 函数使用案例

《PythonCounter函数使用案例》Counter是collections模块中的一个类,专门用于对可迭代对象中的元素进行计数,接下来通过本文给大家介绍PythonCounter函数使用案例... 目录一、Counter函数概述二、基本使用案例(一)列表元素计数(二)字符串字符计数(三)元组计数三、C

Java实现复杂查询优化的7个技巧小结

《Java实现复杂查询优化的7个技巧小结》在Java项目中,复杂查询是开发者面临的“硬骨头”,本文将通过7个实战技巧,结合代码示例和性能对比,手把手教你如何让复杂查询变得优雅,大家可以根据需求进行选择... 目录一、复杂查询的痛点:为何你的代码“又臭又长”1.1冗余变量与中间状态1.2重复查询与性能陷阱1.

Python内存优化的实战技巧分享

《Python内存优化的实战技巧分享》Python作为一门解释型语言,虽然在开发效率上有着显著优势,但在执行效率方面往往被诟病,然而,通过合理的内存优化策略,我们可以让Python程序的运行速度提升3... 目录前言python内存管理机制引用计数机制垃圾回收机制内存泄漏的常见原因1. 循环引用2. 全局变

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Python中的filter() 函数的工作原理及应用技巧

《Python中的filter()函数的工作原理及应用技巧》Python的filter()函数用于筛选序列元素,返回迭代器,适合函数式编程,相比列表推导式,内存更优,尤其适用于大数据集,结合lamb... 目录前言一、基本概念基本语法二、使用方式1. 使用 lambda 函数2. 使用普通函数3. 使用 N

MySQL中REPLACE函数与语句举例详解

《MySQL中REPLACE函数与语句举例详解》在MySQL中REPLACE函数是一个用于处理字符串的强大工具,它的主要功能是替换字符串中的某些子字符串,:本文主要介绍MySQL中REPLACE函... 目录一、REPLACE()函数语法:参数说明:功能说明:示例:二、REPLACE INTO语句语法:参数