本文主要是介绍Deep Coral loss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
import torchdef CORAL(source, target):d = source.data.shape[1] #coral公式中的分母部分ns, nt = source.data.shape[0], target.data.shape[0]# source covariancexm = torch.mean(source, 0, keepdim=True) - source #对应着Cs的分子部分xc = xm.t() @ xm/(ns-1) #对应着Cs的分子部分# target covariancexmt = torch.mean(target, 0, keepdim=True) - target#对应着Ct的分子部分xct = xmt.t() @ xmt/(nt-1)#对应着Ct的分子部分# frobenius norm between source and targetloss = torch.mean(torch.mul((xc - xct), (xc - xct))) #Cs-Ct的点乘loss = loss/(4*d*d)return loss
Coral公式:
只做学习使用,作者也是看了别人进行了学习总结,希望能对你有所帮助。
故障诊断与python学习
Deep CORAL: Correlation Alignment for Deep Domain Adaptation_gdtop818的博客-CSDN博客
这篇关于Deep Coral loss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!