SRGAN loss部分的mindspore代码实现

2024-01-07 16:38

本文主要是介绍SRGAN loss部分的mindspore代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

转载地址:https://bbs.huaweicloud.com/forum/thread-137103-1-1.html

作者: 雨丝儿

pytorch代码实现:https://bbs.huaweicloud.com/forum/forum.php?mod=viewthread&tid=137101

接上篇帖子 ,本贴分享mindspore代码的实现:

import mindspore.common.dtype as mstype

import mindspore.nn as nn

import mindspore.ops as ops

from src.vgg19.define import vgg19

from src.loss.Meanshift import MeanShift

对于Discriminator:

class DiscriminatorLoss(nn.Cell):

    """Loss for discriminator"""

    def __init__(self, discriminator, generator):

        super(DiscriminatorLoss, self).__init__()

        self.discriminator = discriminator

        self.generator = generator

        self.adversarial_criterion = nn.BCELoss()

        ones = ops.Ones()

        zeros = ops.Zeros()

        self.real_lable = ones((16, 1), mstype.float32)

        self.fake_lable = zeros((16, 1), mstype.float32)

    def construct(self, HR_img, LR_img):

        """dloss"""

        hr = HR_img

        lr = LR_img

        # Generating fake high resolution images from real low resolution images.

        sr = self.generator(lr)

        # Let the discriminator realize that the sample is real.

        real_output = self.discriminator(hr)

        d_loss_real = self.adversarial_criterion(real_output, self.real_lable)

        # Let the discriminator realize that the sample is false.

        fake_output = self.discriminator(sr)

        d_loss_fake = self.adversarial_criterion(fake_output, self.fake_lable)

        d_loss = d_loss_fake+d_loss_real

        return  d_loss

class GeneratorLoss(nn.Cell):

    """Loss for generator"""

    def __init__(self, discriminator, generator, vgg_ckpt):

        super(GeneratorLoss, self).__init__()

        self.discriminator = discriminator

        self.generator = generator

        self.mse_loss = nn.MSELoss()

        self.adversarial_criterion = nn.BCELoss()

        ones = ops.Ones()

        self.real_lable = ones((16, 1), mstype.float32)

        self.meanshif = MeanShift()

        self.vgg = vgg19(vgg_ckpt)

        for p in self.meanshif.get_parameters():

            p.requires_grad = False

    def construct(self, HR_img, LR_img):

        """gloss"""

        # L2loss

        hr = HR_img

        lr = LR_img

        sr = self.generator(lr)

        L2_loss = self.mse_loss(sr, hr)

        # adversarialloss

        fake_output = self.discriminator(sr)

        adversarial_loss = self.adversarial_criterion(fake_output, self.real_lable)

        # vggloss

        hr = (hr+1.0)/2.0

        sr = (sr+1.0)/2.0

        hr = self.meanshif(hr)

        sr = self.meanshif(sr)

        hr_feat = self.vgg(hr)

        sr_feat = self.vgg(sr)

        percep_loss = self.mse_loss(hr_feat, sr_feat)

        g_loss = 0.006*percep_loss+1e-3*adversarial_loss+L2_loss

        return  g_loss

这篇关于SRGAN loss部分的mindspore代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/580566

相关文章

使用Python实现表格字段智能去重

《使用Python实现表格字段智能去重》在数据分析和处理过程中,数据清洗是一个至关重要的步骤,其中字段去重是一个常见且关键的任务,下面我们看看如何使用Python进行表格字段智能去重吧... 目录一、引言二、数据重复问题的常见场景与影响三、python在数据清洗中的优势四、基于Python的表格字段智能去重

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Nginx中location实现多条件匹配的方法详解

《Nginx中location实现多条件匹配的方法详解》在Nginx中,location指令用于匹配请求的URI,虽然location本身是基于单一匹配规则的,但可以通过多种方式实现多个条件的匹配逻辑... 目录1. 概述2. 实现多条件匹配的方式2.1 使用多个 location 块2.2 使用正则表达式

使用Apache POI在Java中实现Excel单元格的合并

《使用ApachePOI在Java中实现Excel单元格的合并》在日常工作中,Excel是一个不可或缺的工具,尤其是在处理大量数据时,本文将介绍如何使用ApachePOI库在Java中实现Excel... 目录工具类介绍工具类代码调用示例依赖配置总结在日常工作中,Excel 是一个不可或缺的工http://

SpringBoot实现导出复杂对象到Excel文件

《SpringBoot实现导出复杂对象到Excel文件》这篇文章主要为大家详细介绍了如何使用Hutool和EasyExcel两种方式来实现在SpringBoot项目中导出复杂对象到Excel文件,需要... 在Spring Boot项目中导出复杂对象到Excel文件,可以利用Hutool或EasyExcel

Python如何实现读取csv文件时忽略文件的编码格式

《Python如何实现读取csv文件时忽略文件的编码格式》我们再日常读取csv文件的时候经常会发现csv文件的格式有多种,所以这篇文章为大家介绍了Python如何实现读取csv文件时忽略文件的编码格式... 目录1、背景介绍2、库的安装3、核心代码4、完整代码1、背景介绍我们再日常读取csv文件的时候经常

Golang中map缩容的实现

《Golang中map缩容的实现》本文主要介绍了Go语言中map的扩缩容机制,包括grow和hashGrow方法的处理,具有一定的参考价值,感兴趣的可以了解一下... 目录基本分析带来的隐患为什么不支持缩容基本分析在 Go 底层源码 src/runtime/map.go 中,扩缩容的处理方法是 grow

Go 1.23中Timer无buffer的实现方式详解

《Go1.23中Timer无buffer的实现方式详解》在Go1.23中,Timer的实现通常是通过time包提供的time.Timer类型来实现的,本文主要介绍了Go1.23中Timer无buff... 目录Timer 的基本实现无缓冲区的实现自定义无缓冲 Timer 实现更复杂的 Timer 实现总结在

基于Python实现多语言朗读与单词选择测验

《基于Python实现多语言朗读与单词选择测验》在数字化教育日益普及的今天,开发一款能够支持多语言朗读和单词选择测验的程序,对于语言学习者来说无疑是一个巨大的福音,下面我们就来用Python实现一个这... 目录一、项目概述二、环境准备三、实现朗读功能四、实现单词选择测验五、创建图形用户界面六、运行程序七、

Vue中动态权限到按钮的完整实现方案详解

《Vue中动态权限到按钮的完整实现方案详解》这篇文章主要为大家详细介绍了Vue如何在现有方案的基础上加入对路由的增、删、改、查权限控制,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、数据库设计扩展1.1 修改路由表(routes)1.2 修改角色与路由权限表(role_routes)二、后端接口设计