莫凡Pytorch学习笔记(二)

2023-12-11 20:10
文章标签 学习 笔记 pytorch 莫凡

本文主要是介绍莫凡Pytorch学习笔记(二),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Pytorch回归模型搭建

本篇笔记主要对应于莫凡Pytorch中的3.1节。主要讲了如何使用Pytorch搭建一个回归模型的神经网络。

在Pytorch中自定义一个神经网络时,我们需要继承torch.nn.Module来书写自己的神经网络。在继承该类时,必须重新实现__init__构造函数和forward这两个方法。这里有一些注意点:

  1. 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;

  2. 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替;

  3. forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

接下来我们来自己搭建一个回归模型的神经网络。

数据生成与展示

这里生成一组 y = x 2 y=x^2 y=x2的数据,并加入一些随机噪声。

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)   # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())                  # noisy y data (tensor), shape=(100, 1)
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

在这里插入图片描述

基本网络搭建

我们自定义一个类来完成回归操作

class Net(torch.nn.Module):def __init__(self, n_feature, n_hidden, n_output):# 分别表示feature个数、隐藏层神经元数个数、输出值数目super(Net, self).__init__()self.hidden = torch.nn.Linear(n_feature, n_hidden)self.predict = torch.nn.Linear(n_hidden, n_output)def forward(self, x):# x 是输入数据x = F.relu(self.hidden(x))y = self.predict(x)return y

这是一个两层的神经网络,其包含一个隐藏层即self.hidden,之后便连接一个输出层self.predict。在前向传播时,网络对隐层的输出进行了Relu操作。

网络搭建完成后,我们可以打印输出一下这个网络的基本结构

net = Net(1, 10, 1)
print(net)

得到输出如下

Net((hidden): Linear(in_features=1, out_features=10, bias=True)(predict): Linear(in_features=10, out_features=1, bias=True)
)

设置优化器和损失函数

接下来我们设置网络的优化器和损失函数。
优化方法设置为随机梯度下降法,学习率设置为0.5。
一般回归问题使用最小均方误差作为损失函数。

optimizer = torch.optim.SGD(net.parameters(), lr=0.5)  
loss_func = torch.nn.MSELoss()    # 回归问题采用MSE

训练与展示

最后,我们展示输出并可视化中间过程。

plt.ion()
for step in range(100):prediction = net(x)loss = loss_func(prediction, y)optimizer.zero_grad()   # 首先将所有参数的梯度降为0(因为每次计算梯度后这个值都会保留,不清零就会导致不正确)loss.backward()         # 进行反向传递,计算出计算图中所有节点的梯度optimizer.step()        # 计算完成后,使用optimizer优化这些梯度if step % 20 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data)plt.savefig("./img/02_"+str(step)+".png")plt.pause(0.1)plt.ioff()
plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

可以看到随着训练的进行,loss逐渐降低,模型拟合的效果越来越好。

参考

  1. 莫凡Python:Pytorch动态神经网络,https://mofanpy.com/tutorials/machine-learning/torch/
  2. pytorch教程之nn.Module类详解——使用Module类来自定义模型,https://blog.csdn.net/qq_27825451/article/details/90550890

这篇关于莫凡Pytorch学习笔记(二)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/481961

相关文章

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

Java进阶学习之如何开启远程调式

《Java进阶学习之如何开启远程调式》Java开发中的远程调试是一项至关重要的技能,特别是在处理生产环境的问题或者协作开发时,:本文主要介绍Java进阶学习之如何开启远程调式的相关资料,需要的朋友... 目录概述Java远程调试的开启与底层原理开启Java远程调试底层原理JVM参数总结&nbsMbKKXJx

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用