本文主要是介绍Influence-Balanced Loss 中的Resample策略,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
改进的sampler策略
elif args.train_rule == 'Resample':train_sampler = ImbalancedDatasetSampler(dset_train)
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):def __init__(self, dataset, indices=None, num_samples=None):# if indices is not provided, # all elements in the dataset will be consideredself.indices = list(range(len(dataset))) \if indices is None else indices# if num_samples is not provided, # draw `len(indices)` samples in each iterationself.num_samples = len(self.indices) \if num_samples is None else num_samples # 数据集样本个数# distribution of classes in the dataset label_to_count = [0] * len(np.unique(dataset.targets))for idx in self.indices:label = self._get_label(dataset, idx)label_to_count[label] += 1beta = 0.9999effective_num = 1.0 - np.power(beta, label_to_count)per_cls_weights = (1.0 - beta) / np.array(effective_num) #各类别的权重 per_cls_weights: [0.00248924 0.00202661 0.00689909 0.00975834]# weight for each sampleweights = [per_cls_weights[self._get_label(dataset, idx)]for idx in self.indices] # 各样本的权重self.weights = torch.DoubleTensor(weights)def _get_label(self, dataset, idx):return dataset.targets[idx]def __iter__(self):return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())def __len__(self):return self.num_samples
Class Counts: [410, 506, 146, 103]
per_cls_weights: [0.00248924 0.00202661 0.00689909 0.00975834]0.00248924*410+0.00202661*506+0.00689909*146+103*0.00975834=4.05842922
普通sampler
继承了sampler类,然后重新为数据集中的各样本分配权重。
如果使用的是普通的采样器(sampler),例如 PyTorch 中的 RandomSampler
或简单的顺序采样,每个样本通常被赋予等权重。这意味着在抽样过程中,每个样本被选中的概率是相等的。
在这种情况下,假设数据集中有 𝑁个样本,那么每个样本被选中的概率和权重都是 1/𝑁。这种方式不考虑数据集中可能存在的类别不平衡问题,每个样本被选取的机会完全相同。
例如,如果你有一个包含 100 个样本的数据集,并使用普通的采样器进行随机抽样,则每个样本被选中的概率都是 1%。这种采样方式简单且常用,但在处理类别极度不平衡的数据集时可能不够有效,因为它可能导致模型对多数类过拟合,而忽视了少数类。
ImbalancedDatasetSampler的采样策略的公式和CBReweight的公式差不多
两者都试图通过为每个类别的样本分配不同的权重来解决类别不平衡问题,但应用的场景和具体实现有所不同:
- ImbalancedDatasetSampler:影响的是数据采样过程,通过改变数据输入模型的方式来达成类别平衡。
- CBReweight:直接作用于模型的损失函数,通过改变损失计算方式来强调少数类的重要性。
尽管两者策略相似,但具体实现和影响的环节(数据层面 vs. 模型训练层面)有所区别。
ImbalancedDatasetSampler最后会将整个数据集的每个样本的权重列表送入官方写好的sampler里(继承普通的sampler类),CBReweight会将每个类的权重列表送入官方写好的代码里(交叉熵损失)
这篇关于Influence-Balanced Loss 中的Resample策略的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!