pytorch | with torch.no_grad()

2024-03-31 02:36
文章标签 pytorch torch grad

本文主要是介绍pytorch | with torch.no_grad(),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1.关于with

with 是python中上下文管理器,简单理解,当要进行固定的进入,返回操作时,可以将对应需要的操作,放在with所需要的语句中。比如文件的写入(需要打开关闭文件)等。

以下为一个文件写入使用with的例子。

with open (filename,'w') as sh:    sh.write("#!/bin/bash\n")sh.write("#$ -N "+'IC'+altas+str(patientNumber)+altas+'\n')sh.write("#$ -o "+pathSh+altas+'log.log\n') sh.write("#$ -e "+pathSh+altas+'err.log\n') sh.write('source ~/.bashrc\n')          sh.write('. "/home/kjsun/anaconda3/etc/profile.d/conda.sh"\n')sh.write('conda activate python27\n')sh.write('echo "to python"\n')sh.write('echo "finish"\n')sh.close()

with 后部分,可以将 with 后的语句运行,将其返回结果给到 as 后的变量(sh),之后的代码块对 close 进行操作。

2.关于with torch.no_grad():

在使用 pytorch 时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。而对于 tensor 的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。

以下分别为使用和不使用的情况:

(1)使用with torch.no_grad():

with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))        
print(outputs)

运行结果:

Accuracy of the network on the 10000 test images: 55 %

tensor([[-2.9141, -3.8210, 2.1426, 3.0883, 2.6363, 2.6878, 2.8766, 0.3396,

-4.7505, -3.8502],

[-1.4012, -4.5747, 1.8557, 3.8178, 1.1430, 3.9522, -0.4563, 1.2740,

-3.7763, -3.3633],

[ 1.3090, 0.1812, 0.4852, 0.1315, 0.5297, -0.3215, -2.0045, 1.0426,

-3.2699, -0.5084],

[-0.5357, -1.9851, -0.2835, -0.3110, 2.6453, 0.7452, -1.4148, 5.6919,

-6.3235, -1.6220]])

此时的 outputs 没有属性。

而对应的不使用的情况

for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
print(outputs)

结果如下:

Accuracy of the network on the 10000 test images: 55 %

tensor([[-2.9141, -3.8210, 2.1426, 3.0883, 2.6363, 2.6878, 2.8766, 0.3396,

-4.7505, -3.8502],

[-1.4012, -4.5747, 1.8557, 3.8178, 1.1430, 3.9522, -0.4563, 1.2740,

-3.7763, -3.3633],

[ 1.3090, 0.1812, 0.4852, 0.1315, 0.5297, -0.3215, -2.0045, 1.0426,

-3.2699, -0.5084],

[-0.5357, -1.9851, -0.2835, -0.3110, 2.6453, 0.7452, -1.4148, 5.6919,

-6.3235, -1.6220]], grad_fn=<AddmmBackward>)

可以看到,此时有 grad_fn= 属性,表示,计算的结果在一计算图当中,可以进行梯度反传等操作。但是,两者计算的结果实际上是没有区别的。

这篇关于pytorch | with torch.no_grad()的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/863199

相关文章

如何在pycharm安装torch包

《如何在pycharm安装torch包》:本文主要介绍如何在pycharm安装torch包方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录在pycharm安装torch包适http://www.chinasem.cn配于我电脑的指令为适用的torch包为总结在p

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用