These are reading notes on TMC (Trusted Multi-View Classification), covering the model's parameters, its forward pass, and its code.
Reference paper and links:
Paper: Trusted Multi-View Classification
Link to an explanatory write-up
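1 Parameters
The model takes the number of classes as input. Because this is multi-view classification, it also takes the number of views and the layer dimensions of each view.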
# Constructor of the TMC class
def __init__(self, classes, views, classifier_dims, lambda_epochs=1):
    """
    :param classes: Number of classification categories
    :param views: Number of views
    :param classifier_dims: Dimension of the classifier
    :param lambda_epochs: KL divergence annealing epoch during training
    """
    super(TMC, self).__init__()
    self.views = views
    self.classes = classes
    self.lambda_epochs = lambda_epochs
    # One evidence network per view
    self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.views)])
The Classifier class used in the code initializes the following parameters:
class Classifier(nn.Module):
    def __init__(self, classifier_dims, classes):
        super(Classifier, self).__init__()
        self.num_layers = len(classifier_dims)
        self.fc = nn.ModuleList()
        # Hidden linear layers between consecutive dimensions
        for i in range(self.num_layers-1):
            self.fc.append(nn.Linear(classifier_dims[i], classifier_dims[i+1]))
        # Output layer maps to one value per class
        self.fc.append(nn.Linear(classifier_dims[self.num_layers-1], classes))
        # Softplus keeps the output (the evidence) non-negative
        self.fc.append(nn.Softplus())

    def forward(self, x):
        h = self.fc[0](x)
        for i in range(1, len(self.fc)):
            h = self.fc[i](h)
        return h
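To make the role of classifier_dims concrete, the following sketch replays the loop from the Classifier constructor in plain Python and lists the (input, output) sizes of the Linear layers it would create. The helper name linear_layer_shapes and the example dimensions are hypothetical and chosen only for illustration.

```python
def linear_layer_shapes(classifier_dims, classes):
    """Return the (in_features, out_features) pairs of the Linear layers
    that the Classifier constructor above would create for one view."""
    shapes = []
    num_layers = len(classifier_dims)
    for i in range(num_layers - 1):
        shapes.append((classifier_dims[i], classifier_dims[i + 1]))
    # The last Linear layer always maps to the number of classes.
    shapes.append((classifier_dims[num_layers - 1], classes))
    return shapes


# Hypothetical two-view setup with 10 classes (dimensions are made up).
views_dims = [[240], [76, 32]]
for dims in views_dims:
    print(dims, "->", linear_layer_shapes(dims, classes=10))
```

A single-entry list such as [240] therefore yields one Linear layer from the input features straight to the classes, followed by Softplus.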
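2 Forward pass
The forward pass works as follows. Each view's input is passed through its own Classifier network, whose final Softplus layer produces the evidence. The evidence of each view is stored in a dictionary keyed by the view index.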
# Method of the TMC class
def infer(self, input):
    """
    :param input: Multi-view data
    :return: evidence of every view
    """
    evidence = dict()
    for v_num in range(self.views):
        evidence[v_num] = self.Classifiers[v_num](input[v_num])
    return evidence
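As a small illustration of why the final Softplus layer is suitable for producing evidence, the sketch below evaluates softplus(x) = log(1 + exp(x)) in plain Python on a few made-up logits. The function here is a stand-in written for this note, not the PyTorch implementation.

```python
import math


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Made-up raw outputs of the last Linear layer for one sample.
logits = [-3.0, 0.0, 4.0]
evidence = [softplus(z) for z in logits]
print(evidence)  # every entry is positive, even for negative logits
```

Because every entry is positive, adding 1 to the evidence always gives valid Dirichlet parameters.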
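Next, the values of $u$, $b$, and $S$ are computed from the evidence. With $K$ classes and evidence $e_k$ for class $k$: $\alpha_k = e_k + 1$, $S = \sum_{k=1}^{K} \alpha_k$, $b_k = e_k / S = (\alpha_k - 1) / S$, and $u = K / S$, so that $u + \sum_{k=1}^{K} b_k = 1$. Dempster-Shafer (DS) combination theory is then used to fuse the opinions of the individual views.
The corresponding code appears in the subsections that follow.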
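The quantities b, u, and S can be checked by hand on a single sample. The sketch below uses the same relations that appear in the DS_Combin code of these notes (E = alpha - 1, b = E / S, u = classes / S). The function name opinion_from_evidence and the evidence values are assumptions made for illustration.

```python
def opinion_from_evidence(evidence):
    """Turn the per-class evidence of one sample into belief masses b,
    uncertainty u, and Dirichlet strength S."""
    K = len(evidence)
    alpha = [e + 1.0 for e in evidence]  # Dirichlet parameters
    S = sum(alpha)                       # Dirichlet strength
    b = [e / S for e in evidence]        # belief mass per class
    u = K / S                            # overall uncertainty
    return b, u, S


# Made-up evidence for a 3-class problem.
b, u, S = opinion_from_evidence([6.0, 1.0, 0.0])
print(b, u, S)
```

With no evidence at all, every b_k is 0 and u is 1, which is the "completely uncertain" opinion.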
2.1 Evaluating the formulas
The code below also computes the loss.
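import torch
import torch.nn.functional as F

# KL(alp, c) is a helper that is not shown in these notes.
def ce_loss(p, alpha, c, global_step, annealing_step):
    S = torch.sum(alpha, dim=1, keepdim=True)
    E = alpha - 1
    label = F.one_hot(p, num_classes=c)
    # Expected cross-entropy under the Dirichlet distribution
    A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)

    # KL weight grows linearly from 0 to 1 over annealing_step steps
    annealing_coef = min(1, global_step / annealing_step)

    # Remove the evidence of the true class, keep the misleading evidence
    alp = E * (1 - label) + 1
    B = annealing_coef * KL(alp, c)

    return (A + B)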
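Two steps of ce_loss are easy to verify in isolation: the annealing coefficient and the expression alp = E * (1 - label) + 1. The sketch below reproduces both in plain Python for a single sample. The function names and numbers are hypothetical, and the digamma and KL terms are deliberately left out.

```python
def annealing_coef(global_step, annealing_step):
    """Weight of the KL term; grows linearly and is capped at 1."""
    return min(1, global_step / annealing_step)


def remove_target_evidence(alpha, label_index):
    """Plain-Python version of alp = E * (1 - label) + 1 for one sample:
    the true class is reset to 1, the other classes keep their alpha."""
    E = [a - 1.0 for a in alpha]
    return [E[k] * (0.0 if k == label_index else 1.0) + 1.0
            for k in range(len(alpha))]


print(annealing_coef(5, 50))                        # early in training
print(annealing_coef(100, 50))                      # capped at 1
print(remove_target_evidence([7.0, 2.0, 1.0], 0))   # true class is 0
```

The KL term therefore only penalizes evidence assigned to wrong classes, and its influence is small at the start of training.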
2.2 DS combination theory
# Method of the TMC class
def DS_Combin(self, alpha):
    """
    :param alpha: All Dirichlet distribution parameters.
    :return: Combined Dirichlet distribution parameters.
    """
    def DS_Combin_two(alpha1, alpha2):
        """
        :param alpha1: Dirichlet distribution parameters of view 1
        :param alpha2: Dirichlet distribution parameters of view 2
        :return: Combined Dirichlet distribution parameters
        """
        alpha = dict()
        alpha[0], alpha[1] = alpha1, alpha2
        b, S, E, u = dict(), dict(), dict(), dict()
        for v in range(2):
            S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
            E[v] = alpha[v]-1
            b[v] = E[v]/(S[v].expand(E[v].shape))
            u[v] = self.classes/S[v]

        # b^0 @ b^(0+1)
        bb = torch.bmm(b[0].view(-1, self.classes, 1), b[1].view(-1, 1, self.classes))
        # b^0 * u^1
        uv1_expand = u[1].expand(b[0].shape)
        bu = torch.mul(b[0], uv1_expand)
        # b^1 * u^0
        uv_expand = u[0].expand(b[0].shape)
        ub = torch.mul(b[1], uv_expand)
        # calculate C
        bb_sum = torch.sum(bb, dim=(1, 2), out=None)
        bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
        C = bb_sum - bb_diag

        # calculate b^a
        b_a = (torch.mul(b[0], b[1]) + bu + ub)/((1-C).view(-1, 1).expand(b[0].shape))
        # calculate u^a
        u_a = torch.mul(u[0], u[1])/((1-C).view(-1, 1).expand(u[0].shape))

        # calculate new S
        S_a = self.classes / u_a
        # calculate new e_k
        e_a = torch.mul(b_a, S_a.expand(b_a.shape))
        alpha_a = e_a + 1
        return alpha_a

    # Fold the views together pairwise: ((view 0, view 1), view 2), ...
    for v in range(len(alpha)-1):
        if v==0:
            alpha_a = DS_Combin_two(alpha[0], alpha[1])
        else:
            alpha_a = DS_Combin_two(alpha_a, alpha[v+1])
    return alpha_a
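The combination rule in DS_Combin_two can be followed by hand for one sample. The sketch below is a plain-Python rendering of the same arithmetic (conflict C, combined belief b^a, combined uncertainty u^a, then back to Dirichlet parameters). The function name ds_combine_two and the two example opinions are hypothetical.

```python
def ds_combine_two(b1, u1, b2, u2):
    """Combine two opinions (belief masses plus uncertainty) for one sample."""
    K = len(b1)
    # Conflict C: mass that the two views assign to different classes.
    C = sum(b1[i] * b2[j] for i in range(K) for j in range(K) if i != j)
    scale = 1.0 / (1.0 - C)
    b = [scale * (b1[k] * b2[k] + b1[k] * u2 + b2[k] * u1) for k in range(K)]
    u = scale * u1 * u2
    return b, u


# Two made-up opinions over 3 classes; each satisfies sum(b) + u = 1.
b, u = ds_combine_two([0.6, 0.1, 0.0], 0.3, [0.5, 0.0, 0.1], 0.4)

# Map the combined opinion back to Dirichlet parameters, as the code above does.
K = 3
S = K / u
alpha = [bk * S + 1.0 for bk in b]
print(b, u, alpha)
```

In this example both views favor class 0, so the combined uncertainty is lower than that of either view on its own.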
3 Obtaining M through learning