M2m中的采样

2024-05-28 05:20
文章标签 采样 m2m

本文主要是介绍M2m中的采样,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 采样的完整代码

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSamplerdef get_oversampled_data(dataset, num_sample_per_class):""" Generate a list of indices that represents oversampling of the dataset. """targets = np.array(dataset.targets)class_sample_count = np.array([num_sample_per_class[target] for target in targets])weight = 1. / class_sample_countsamples_weight = torch.from_numpy(weight)sampler = WeightedRandomSampler(samples_weight, len(samples_weight))return samplerdef get_val_test_data(dataset, num_test_samples):""" Split dataset into validation and test indices. """num_classes = 10targets = dataset.targetstest_indices = []val_indices = []for i in range(num_classes):indices = [j for j, x in enumerate(targets) if x == i]np.random.shuffle(indices)val_indices.extend(indices[:num_test_samples])test_indices.extend(indices[num_test_samples:num_test_samples*2])return val_indices, test_indicesdef get_oversampled(dataset_name, num_sample_per_class, batch_size, transform_train, transform_test):""" Create training and testing loaders with oversampling for imbalance. """dataset_class = datasets.__dict__[dataset_presets[dataset_name]['class']]dataset_train = dataset_class(root='./data', train=True, download=True, transform=transform_train)dataset_test = dataset_class(root='./data', train=False, download=True, transform=transform_test)# Oversamplingsampler = get_oversampled_data(dataset_train, num_sample_per_class)train_loader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler)# Validation and Test splitval_idx, test_idx = get_val_test_data(dataset_test, 1000)val_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(val_idx))test_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(test_idx))return train_loader, val_loader, test_loader# Configuration and run
dataset_presets = {'cifar10': {'class': 'CIFAR10', 'num_classes': 10}
}
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
num_sample_per_class = [500] * 10  # Pretend we want equal class distributiontrain_loader, val_loader, test_loader = get_oversampled('cifar10', num_sample_per_class, 64, transform, transform)# Print out some info from loaders
for i, (inputs, targets) in enumerate(train_loader):print(f'Batch {i}, Targets Counts: {torch.bincount(targets)}')if i == 1:  # Just show first two batches for demonstrationbreak

WeightedRandomSampler类的__iter__

def __iter__(self) -> Iterator[int]:rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)return iter(rand_tensor.tolist())
  • 方法功能:此方法实现了迭代器协议,允许WeightedRandomSampler对象在迭代中返回一系列随机选择的索引。

过采样的效果

get_oversampled函数中,使用了WeightedRandomSampler来实现过采样的逻辑。这个过程虽然看起来是通过权重调整样本的选取概率,但实际上,通过这种方式也可以达到过采样的效果,尤其是当设置replacement=True时。让我们更详细地分析一下这一点:

权重的分配

权重是根据num_sample_per_class数组分配的,这个数组定义了每个类别希望被采样到的频率。在数据加载过程中,每个类别的样本将根据其在num_sample_per_class中对应的值获得一个权重。权重越大的类别在每次迭代中

被选中的概率也越大。这样,通过调整这些权重,我们可以控制模型在训练过程中看到的每个类别样本的频率,实现对类别不平衡的处理。

过采样的实现

在使用WeightedRandomSampler时,关键的参数是replacement

  • 如果replacement=True:这允许同一个样本在一次抽样中被多次选择,即进行了过采样。对于少数类的样本来说,即使它们在数据集中的绝对数量不多,也可以通过这种方式增加它们在每个训练批次中出现的次数,从而让模型更频繁地从这些少数类样本学习。

  • 如果replacement=False:则每个样本只能被抽样一次,这通常用于不放回的抽样。在这种模式下,WeightedRandomSampler不会直接导致过采样,但可以用来确保每个类别在数据批次中都有均等的代表性,从而帮助模型学习到更平衡的特征。

这篇关于M2m中的采样的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1009653

相关文章

重复采样魔法:用更多样本击败单次尝试的最强模型

这篇文章探讨了通过增加生成样本的数量来扩展大型语言模型(LLMs)在推理任务中的表现。 研究发现,重复采样可以显著提高模型的覆盖率,特别是在具有自动验证工具的任务中。研究还发现,覆盖率与样本数量之间的关系可以用指数幂律建模,揭示了推理时间的扩展规律。尽管多数投票和奖励模型在样本数量增加时趋于饱和,但在没有自动验证工具的任务中,识别正确样本仍然是一个重要的研究方向。 总体而言,重复采样提供了一种

研究纹理采样器在像素级别的采样位置

问题 【纹理采样器】是一个基础的概念。假设有一个正方形面片,顶点的UV范围是0.0~1.0,那么在这个正方形面片上采样一张纹理时,会呈现出完整的纹理。 但我现在关注的问题是,在像素级别上,采样的位置是怎样的。具体来讲:对于UV值是(0.0,0.0)的点,它对应的采样位置是纹理最左上角像素的中心?还是纹理最左上角像素的左上角?即,下面左右哪个是正确的情况? 在宏观上,尤其是像素较多的时候,二者

爆改YOLOv8|利用yolov10的SCDown改进yolov8-下采样

1, 本文介绍 YOLOv10 的 SCDown 方法来优化 YOLOv8 的下采样过程。SCDown 通过点卷积调整通道维度,再通过深度卷积进行空间下采样,从而减少了计算成本和参数数量。这种方法不仅降低了延迟,还在保持下采样过程信息的同时提供了竞争性的性能。 关于SCDown 的详细介绍可以看论文:https://arxiv.org/pdf/2405.14458 本文将讲解如何将SCDow

优化采样参数提升大语言模型响应质量:深入分析温度、top_p、top_k和min_p的随机解码策略

当向大语言模型(LLM)提出查询时,模型会为其词汇表中的每个可能标记输出概率值。从这个概率分布中采样一个标记后,我们可以将该标记附加到输入提示中,使LLM能够继续输出下一个标记的概率。这个采样过程可以通过诸如 temperature 和 top_p 等参数进行精确控制。但是你是否曾深入思考过temperature和top_p参数的具体作用? 本文将详细解析并可视化定义LLM输出行为的

word2vec 两个模型,两个加速方法 负采样加速Skip-gram模型 层序Softmax加速CBOW模型 item2vec 双塔模型 (DSSM双塔模型)

推荐领域(DSSM双塔模型): https://www.cnblogs.com/wilson0068/p/12881258.html   word2vec  word2vec笔记和实现 理解 Word2Vec 之 Skip-Gram 模型 上面这两个链接能让你彻底明白word2vec,不要搞什么公式,看完也是不知所云,也没说到本质. 目前用的比较多的都是Skip-gram模型 Go

YoloV10改进策略:下采样改进|自研下采样模块(独家改进)|疯狂涨点|附结构图

文章目录 摘要自研下采样模块及其变种第一种改进方法 YoloV10官方测试结果改进方法测试结果总结 摘要 本文介绍我自研的下采样模块。本次改进的下采样模块是一种通用的改进方法,你可以用分类任务的主干网络中,也可以用在分割和超分的任务中。已经有粉丝用来改进ConvNext模型,取得了非常好的效果,配合一些其他的改进,发一篇CVPR、ECCV之类的顶会完全没有问题。 本次我将这个模

CUDAPCL 点云体素下采样

文章目录 一、简介二、实现代码三、实现效果参考资料 一、简介 体素下采样是指使用常规体素网格从输入点云创建均匀下采样的点云。它经常被用作许多点云处理任务的预处理步骤。该算法分为两步操作: (1)并行的将每个点分配到其所处的体素中。 (2)并行遍历所有体素,并求取每个体素中所有点的质心点。 二、实现代码 VoxelSample.cuh #ifndef VOXELS

Open3D 体素随机下采样

目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 三、实现效果 3.1原始点云 3.2体素下采样后点云 Open3D点云算法汇总及实战案例汇总的目录地址: Open3D点云算法与点云深度学习案例汇总(长期更新)-CSDN博客 一、概述         体素随机下采样是一种常用的点云简化方法,通过将点云划分为立方体体素网格,并从每个体素中随机

matlab实现kaiser窗+时域采样序列(不管原信号拉伸成什么样子)是一样的,变到频谱后再采样就是一样的频域序列。

下图窗2的频谱在周期化的时候应该是2(w-k*pi/T)我直接对2w减得写错了 可见这两个kaiser窗频谱不一样,采样间隔为2T的窗,频谱压缩2倍,且以原采样频率的一半周期化。 但是这两个不同的kaiser窗在频域采样点的值使完全一致的。这和matlab模拟dft的过程吻合 也说明不管原信号拉伸成什么样子,只要时域采样序列是一样的,变到频谱后再采样就是一样的频域序列。 (与原信号的

如何通过更好的采样参数来提高 LLM 响应率

深入研究使用温度、top_p、top_k 和 min_p 进行随机解码 当你向大型语言模型 (LLM) 提出问题时,该模型会输出其词汇表中每个可能标记的概率。 从该概率分布中抽取一个标记后,我们可以将选定的标记附加到我们的输入提示中,以便 LLM 可以输出下一个标记的概率。 temperature该采样过程可以通过著名的和等参数来控制top_p。 在本文中,我将解释并直观地展示定义 LLM