【笔记】scatter_函数:用法如 torch.zeros(target.size(0), 2).scatter_(1,target,1).to(self.device)

2024-01-29 06:40

本文主要是介绍【笔记】scatter_函数:用法如 torch.zeros(target.size(0), 2).scatter_(1,target,1).to(self.device),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

target内容:

tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
        0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 0])

类型:

<class 'torch.Tensor'>

target.shape

torch.Size([100])

target.size()

torch.Size([100])

程序1:error 

要知道错误的原因

RuntimeError: Expected index [1, 100] to be smaller than self [100, 2] apart from dimension 1 and to be smaller size than src [100, 2]
 

程序2: true

import torch
from PIL import Image
import numpy as np
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoaderclass mydataset(Dataset):def __init__(self, path):self.path = pathself.dataset = os.listdir(self.path)self.mean = [0.4878, 0.4545, 0.4168]self.std = [0.2623, 0.2555, 0.2577]def __getitem__(self, index):name = self.dataset[index]name_list = name.split(".")target = int(name_list[0])target = torch.tensor(target)img = Image.open(os.path.join(self.path, name))img = np.array(img) / 255# 去均值img = (img - self.mean) / self.std# img 是 float64data = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)return data, targetdef __len__(self):return len(self.dataset)class mynetwork(nn.Module):def __init__(self):super(mynetwork, self).__init__()# 有序容器self.line1 = nn.Sequential(nn.Linear(3 * 100 * 100, 5120),nn.ReLU(),nn.Linear(5120, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 2560),nn.ReLU(),nn.Linear(2560, 512),nn.ReLU(),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 2),)#  parse  vt. 解析;从语法上分析def forward(self, parse):data = torch.reshape(parse, shape=(-1, 3 * 100 * 100))return self.line1(data)class train(object):def __init__(self, path):self.path = pathself.test_dataset = mydataset(self.path)self.train_dataset = mydataset(self.path)self.criterion = torch.nn.MSELoss()self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.net = mynetwork().to(self.device)self.optimize = torch.optim.Adam(self.net.parameters())def dataloader(self, batch):train_data_loader = DataLoader(dataset=self.train_dataset, batch_size=batch, shuffle=True)test_data_loader = DataLoader(dataset=self.test_dataset, batch_size=batch, shuffle=True)return train_data_loader, test_data_loaderdef trainnet(self, batch, epoch):train_data_loader, test_data_loader = self.dataloader(batch)losses = []accuracy = []for i in range(epoch):for j, (input, target) in enumerate(train_data_loader):input = input.to(self.device)output = self.net(input)print(target,type(target),target.shape,target.size())target = torch.zeros(target.size(0), 2).scatter_(1,target.view(1,-1),1).to(self.device)print(target,type(target),target.shape,target.size())print(target)input()if __name__ == "__main__":path = r"./cat_dog/img"t = train(path)t.trainnet(100, 10)

输出:

/home/wangbin/anaconda3/envs/deep_learning/bin/python3.7 /media/wangbin/F/深度学习_程序/dog_cat/cat_dog.py
tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0,0, 1, 1, 0]) <class 'torch.Tensor'> torch.Size([100]) torch.Size([100])
tensor([[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.]], device='cuda:0') <class 'torch.Tensor'> torch.Size([100, 2]) torch.Size([100, 2])
tensor([[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[0., 1.],[0., 1.],[1., 0.],[0., 1.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[1., 0.],[0., 1.],[0., 1.],[1., 0.]], device='cuda:0')

附:

 

函数资料:

torch._C._TensorBase._TensorBase def scatter_(self,
             dim: int,
             index: Any,
             src: Any,
             reduce: str = None) -> None
scatter_(dim, index, src, reduce=None) -> Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
This is the reverse operation of the manner described in ~Tensor.gather.
self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.
Moreover, as for ~Tensor.gather, the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.
Additionally accepts an optional reduce argument that allows specification of an optional reduction operation, which is applied to all values in the tensor src into self at the indicies specified in the index. For each value in src, the reduction operation is applied to an index in self which is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
Given a 3-D tensor and reduction using the multiplication operation, self is updated as:
self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2
Reducing with the addition operation is the same as using ~torch.Tensor.scatter_add_.
Note
Reduction is not yet implemented for the CUDA backend.
Example:
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])

>>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
>>> z
tensor([[1.0000, 1.0000, 1.2300, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.2300]])
Params:
dim – the axis along which to index
index – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
src – the source element(s) to scatter, incase `value` is not specified
reduce – reduction operation to apply, can be either 'add' or 'multiply'.
  < Python 3.7 (deep_learning) >

这篇关于【笔记】scatter_函数:用法如 torch.zeros(target.size(0), 2).scatter_(1,target,1).to(self.device)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/656061

相关文章

Java中的数组与集合基本用法详解

《Java中的数组与集合基本用法详解》本文介绍了Java数组和集合框架的基础知识,数组部分涵盖了一维、二维及多维数组的声明、初始化、访问与遍历方法,以及Arrays类的常用操作,对Java数组与集合相... 目录一、Java数组基础1.1 数组结构概述1.2 一维数组1.2.1 声明与初始化1.2.2 访问

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python中你不知道的gzip高级用法分享

《Python中你不知道的gzip高级用法分享》在当今大数据时代,数据存储和传输成本已成为每个开发者必须考虑的问题,Python内置的gzip模块提供了一种简单高效的解决方案,下面小编就来和大家详细讲... 目录前言:为什么数据压缩如此重要1. gzip 模块基础介绍2. 基本压缩与解压缩操作2.1 压缩文

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos

解读GC日志中的各项指标用法

《解读GC日志中的各项指标用法》:本文主要介绍GC日志中的各项指标用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基础 GC 日志格式(以 G1 为例)1. Minor GC 日志2. Full GC 日志二、关键指标解析1. GC 类型与触发原因2. 堆

MySQL数据库中ENUM的用法是什么详解

《MySQL数据库中ENUM的用法是什么详解》ENUM是一个字符串对象,用于指定一组预定义的值,并可在创建表时使用,下面:本文主要介绍MySQL数据库中ENUM的用法是什么的相关资料,文中通过代码... 目录mysql 中 ENUM 的用法一、ENUM 的定义与语法二、ENUM 的特点三、ENUM 的用法1

JavaSE正则表达式用法总结大全

《JavaSE正则表达式用法总结大全》正则表达式就是由一些特定的字符组成,代表的是一个规则,:本文主要介绍JavaSE正则表达式用法的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录常用的正则表达式匹配符正则表China编程达式常用的类Pattern类Matcher类PatternSynta

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

MySQL之InnoDB存储引擎中的索引用法及说明

《MySQL之InnoDB存储引擎中的索引用法及说明》:本文主要介绍MySQL之InnoDB存储引擎中的索引用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录1、背景2、准备3、正篇【1】存储用户记录的数据页【2】存储目录项记录的数据页【3】聚簇索引【4】二