_get_gt_mask、cat_mask、_get_other_mask

2024-09-06 23:36
文章标签 mask get cat gt

本文主要是介绍_get_gt_mask、cat_mask、_get_other_mask,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

import torch# 定义获取标签掩码的函数
def _get_gt_mask(logits, target):print("原始 logits:\n", logits)print("目标 target:\n", target)# 将 target 拉平为一维张量target = target.reshape(-1)print("拉平后的 target:\n", target)# 创建一个和 logits 大小相同的全零张量,然后根据 target 将对应的类别位置设置为1mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()print("生成的标签掩码 mask:\n", mask)# 返回根据 target 设置的标签掩码return mask# 定义组合掩码的函数
def cat_mask(t, mask1, mask2):print("输入张量 t:\n", t)print("标签掩码 mask1:\n", mask1)print("非标签掩码 mask2:\n", mask2)# 计算 mask1 对应的 t 值,sum(dim=1) 表示在类别维度上进行求和t1 = (t * mask1).sum(dim=1, keepdims=True)print("标签类别的加权和 t1:\n", t1)# 计算 mask2 对应的 t 值t2 = (t * mask2).sum(1, keepdims=True)print("非标签类别的加权和 t2:\n", t2)# 将两个值拼接成新的张量rt = torch.cat([t1, t2], dim=1)print("拼接后的结果 rt:\n", rt)return rt# 示例:假设有3个样本和5个类别的logits
logits = torch.tensor([[2.0, 1.0, 0.1, 3.0, 0.5],[1.0, 3.0, 2.5, 0.5, 0.3],[0.5, 2.2, 1.1, 4.0, 1.5]])# 对应的标签 target
target = torch.tensor([3, 1, 4])  # 每个样本的正确类别是3, 1, 4# 获取标签掩码
gt_mask = _get_gt_mask(logits, target)# 获取非标签掩码
def _get_other_mask(logits, target):print("原始 logits:\n", logits)print("目标 target:\n", target)# 将 target 拉平为一维张量target = target.reshape(-1)print("拉平后的 target:\n", target)# 创建一个和 logits 大小相同的全1张量,然后根据 target 将对应的类别位置设置为0mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()print("生成的非标签掩码 mask:\n", mask)return maskother_mask = _get_other_mask(logits, target)# 假设有某些 softmax 结果
t = torch.softmax(logits, dim=1)
print("Softmax 后的 logits (概率值):\n", t)# 使用标签掩码和非标签掩码进行组合
combined = cat_mask(t, gt_mask, other_mask)
print("最终组合后的结果:\n", combined)

 

这篇关于_get_gt_mask、cat_mask、_get_other_mask的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1143439

相关文章

SpringBoot中Get请求和POST请求接收参数示例详解

《SpringBoot中Get请求和POST请求接收参数示例详解》文章详细介绍了SpringBoot中Get请求和POST请求的参数接收方式,包括方法形参接收参数、实体类接收参数、HttpServle... 目录1、Get请求1.1 方法形参接收参数 这种方式一般适用参数比较少的情况,并且前后端参数名称必须

10 Source-Get-Post-JsonP 网络请求

划重点 使用vue-resource.js库 进行网络请求操作POST : this.$http.post ( … )GET : this.$http.get ( … ) 小鸡炖蘑菇 <!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-w

API28_OKgo_get注意事项

1: implementation 'com.lzy.net:okgo:2.1.4' 2:在BaseApplication中onCreate()中初始化initOKgo() private void initOKgo() {//---------这里给出的是示例代码,告诉你可以这么传,实际使用的时候,根据需要传,不需要就不传-------------//HttpHeaders headers

项目一(一) HttpClient中的POST请求和GET请求

HttpClient中的POST请求和GET请求 一、HttpClient简述 HttpClient是Apache Jakarta Common下的子项目,用来提供高效的、最新的、功能丰富的支持HTTP协议的客户端编程工具包,并且它支持HTTP协议最新的版本和建议。HttpClient已经应用在很多的项目中,比如Apache Jakarta上很著名的另外两个开源项目Cactus和HTMLU

apt-get update更新源时,出现“Hash Sum mismatch”问题

转载自:apt-get update更新源时,出现“Hash Sum mismatch”问题 当使用apt-get update更新源时,出现下面“Hash Sum mismatch”的报错,具体如下: root@localhost:~# apt-get update ...... ...... W: Failed to fetch http://us.archive.ubuntu.com/ub

Flutter-使用dio插件请求网络(get ,post,下载文件)

引入库:dio: ^2.1.13可直接运行的代码:包含了post,get 下载文件import 'package:flutter/material.dart';import 'package:dio/dio.dart';void main() {runApp(new MaterialApp(title: 'Container demo',home: new visitNetPage(),)

Flutter-加三方库卡在flutter package get 的解决办法

Windows PUB_HOSTED_URL ===== https://pub.flutter-io.cnFLUTTER_STORAGE_BASE_URL ===== https://storage.flutter-io.cn 增加两个环境变量,然后执行一下 flutter doctor命令。问题完美解决。

【tensorflow 使用错误】tensorflow2.0 过程中出现 Error : Failed to get convolution algorithm

如果在使用 tensorflow 过程中出现 Error : Failed to get convolution algorithm ,这是因为显卡内存被耗尽了。 解决办法: 在代码的开头加入如下两句,动态分配显存 physical_device = tf.config.experimental.list_physical_devices("GPU")tf.config.experiment

C#通过GET/POST方式发送Http请求

介绍http请求的两种方式,get和post方式。并用C#语言实现,如何请求url并获取返回的数据 两者的区别: 参数 Get请求把提交的数据进行简单编码,同时将url的一部分发送到服务器   比如url:Http://127.0.0.1/login.jsp?Name=zhangshi&Age=30&Submit=%cc%E+%BD%BB   所以get请求方式提交的数据存在一定的安全隐患

使用 GET 方法读取表单数据

HelloForm源码: package firstweb;// 导入必需的 java 库import java.io.*;import javax.servlet.*;import javax.servlet.http.*;// 扩展 HttpServlet 类public class HelloForm extends HttpServlet {public void doGet