本文主要是介绍【Python/Pytorch - 网络模型】-- TV Loss损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
文章目录
- 00 写在前面
- 01 基于Pytorch版本的TV Loss代码
- 02 论文下载
00 写在前面
在医学图像重建过程中,经常在代价方程中加入TV 正则项,该正则项作为去噪项,对于重建可以起到很大帮助作用。但是对于一些纹理细节要求较高的任务,加入TV 正则项,在一定程度上可能会降低纹理细节。
对于连续函数,其表达式为:
对于图片而言,即为离散的数值,求每一个像素和横向下一个像素的差的平方,加上纵向下一个像素的差的平方,再开β/2次根:
01 基于Pytorch版本的TV Loss代码
import torch
from torch.autograd import Variableclass TVLoss(torch.nn.Module):"""TV loss"""def __init__(self, weight=1):super(TVLoss, self).__init__()self.weight = weightdef forward(self, x):batch_size = x.size()[0]h_x = x.size()[2]w_x = x.size()[3]count_h = self._tensor_size(x[:, :, 1:, :])count_w = self._tensor_size(x[:, :, :, 1:])h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()return self.weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_sizedef _tensor_size(self, t):return t.size()[1] * t.size()[2] * t.size()[3]if __name__ == "__main__":x = Variable(torch.FloatTensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]).view(1, 2, 3, 3),requires_grad=True)tv = TVLoss()result = tv(x)print(result)
02 论文下载
Understanding Deep Image Representations by Inverting Them
这篇关于【Python/Pytorch - 网络模型】-- TV Loss损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!