Keras高级--构建复杂的自定义Losses和Metrics

2023-12-20 10:38

本文主要是介绍Keras高级--构建复杂的自定义Losses和Metrics,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Advanced Keras — Constructing Complex Custom Losses and Metrics

    • 1. 背景
    • 2. 自定义loss函数
    • 3. 多参数loss/metric函数

翻译自:https://towardsdatascience.com/advanced-keras-constructing-complex-custom-losses-and-metrics-c07ca130a618
作者:Eyal Zakkay
转载请注明出处以及本文链接

本文将介绍一个简单技巧来在Keras中构建自定义loss函数,它可以接收除 y_truey_pred 之外的参数。

1. 背景

在Keras中编译模型时,我们为compile函数提供所需的loss与metrics。例如:

model.compile(loss=’mean_squared_error’, optimizer=’sgd’, metrics=‘acc’)

出于可读性目的,从现在开始关注loss函数。但是,大部分内容也适用于metrics。
从Keras的有关loss函数的文档中能得到:

你可以传递一个现有的损失函数名,或者一个 TensorFlow/Theano 符号函数。该符号函数为每个数据点返回一个标量,有以下两个参数:
y_true: 真实标签。TensorFlow/Theano 张量。
y_pred: 预测值。TensorFlow/Theano 张量,其 shape 与 y_true 相同。

所以如果我们想使用常用的loss函数,比如MSE或者 Categorical Cross-entropy,我们可以通过传入合适的参数来很容易实现。
Keras的文档中提供了可用loss和metrics的列表。

2. 自定义loss函数

当我们需要使用实现更复杂的loss函数或者metric时,就需要构造自定义函数并且传给model.compile
例如,构建自定义metric(来自Keras文档):

import keras.backend as Kdef mean_pred(y_true, y_pred):return K.mean(y_pred)model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy', mean_pred])

3. 多参数loss/metric函数

可能你已经注意到,loss函数必须只接受2个参数y_truey_pred,分别是target张量和模型output张量。但是如果我们希望loss/metric函数依赖于除这两个之外的其他张量呢?
为此,我们需要使用 function closure。首先创建一个loss函数(随意取参数)返回一个y_truey_pred的函数。
例如,如果我们想(由于某种原因)创建一个loss函数,其将第一层中所有activations的均方值添加到MSE:


# Build a model
inputs = Input(shape=(128,))
layer1 = Dense(64, activation='relu')(inputs)
layer2 = Dense(64, activation='relu')(layer1)
predictions = Dense(10, activation='softmax')(layer2)
model = Model(inputs=inputs, outputs=predictions)# Define custom loss
def custom_loss(layer):# Create a loss function that adds the MSE loss to the mean of all squared activations of a specific layerdef loss(y_true,y_pred):return K.mean(K.square(y_pred - y_true) + K.square(layer), axis=-1)# Return a functionreturn loss# Compile the model
model.compile(optimizer='adam',loss=custom_loss(layer), # Call the loss function with the selected layermetrics=['accuracy'])# train
model.fit(data, labels)  

注意,我们创建了一个函数(不限制参数的数量),它返回了一个合法的loss函数,该函数可以访问其封闭函数的参数。
一个更具体的例子:

假设设计一个变分自动编码器。希望模型能够从编码的潜在空间重建其输入。但还希望潜在空间中的编码(近似)正态分布。
前者的目标可以通过设计一个仅依赖于输入和期望输出y_truey_pred的重建损失来实现。对于后者,需要设计一个对潜在张量进行操作的损失项(例如,Kullback Leibler损失)。为了让你的loss函数能够访问这个中间张量,我们刚学到的技巧可以派上用场。
使用示例:

    def model_loss(self):"""" Wrapper function which calculates auxiliary values for the complete loss function.Returns a *function* which calculates the complete loss given only the input and target output """# KL losskl_loss = self.calculate_kl_loss# Reconstruction lossmd_loss_func = self.calculate_md_loss# KL weight (to be used by total loss and by annealing scheduler)self.kl_weight = K.variable(self.hps['kl_weight_start'], name='kl_weight')kl_weight = self.kl_weightdef seq2seq_loss(y_true, y_pred):""" Final loss calculation function to be passed to optimizer"""# Reconstruction lossmd_loss = md_loss_func(y_true, y_pred)# Full lossmodel_loss = kl_weight*kl_loss() + md_lossreturn model_lossreturn seq2seq_loss

此示例是Sequence to Sequence Variational Autoencoder模型的一部分,更多上下文和完整代码访问此

使用的Keras版本=2.2.4

Reference

[1]. Keras — Losses
[2]. Keras — Metrics
[3]. Github Issue — Passing additional arguments to objective function

Thanks to Ludovic Benistant.

这篇关于Keras高级--构建复杂的自定义Losses和Metrics的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/515734

相关文章

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

如何自定义Nginx JSON日志格式配置

《如何自定义NginxJSON日志格式配置》Nginx作为最流行的Web服务器之一,其灵活的日志配置能力允许我们根据需求定制日志格式,本文将详细介绍如何配置Nginx以JSON格式记录访问日志,这种... 目录前言为什么选择jsON格式日志?配置步骤详解1. 安装Nginx服务2. 自定义JSON日志格式各

Android自定义Scrollbar的两种实现方式

《Android自定义Scrollbar的两种实现方式》本文介绍两种实现自定义滚动条的方法,分别通过ItemDecoration方案和独立View方案实现滚动条定制化,文章通过代码示例讲解的非常详细,... 目录方案一:ItemDecoration实现(推荐用于RecyclerView)实现原理完整代码实现

使用Python和python-pptx构建Markdown到PowerPoint转换器

《使用Python和python-pptx构建Markdown到PowerPoint转换器》在这篇博客中,我们将深入分析一个使用Python开发的应用程序,该程序可以将Markdown文件转换为Pow... 目录引言应用概述代码结构与分析1. 类定义与初始化2. 事件处理3. Markdown 处理4. 转

基于Spring实现自定义错误信息返回详解

《基于Spring实现自定义错误信息返回详解》这篇文章主要为大家详细介绍了如何基于Spring实现自定义错误信息返回效果,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录背景目标实现产出背景Spring 提供了 @RestConChina编程trollerAdvice 用来实现 HTT

SpringSecurity 认证、注销、权限控制功能(注销、记住密码、自定义登入页)

《SpringSecurity认证、注销、权限控制功能(注销、记住密码、自定义登入页)》SpringSecurity是一个强大的Java框架,用于保护应用程序的安全性,它提供了一套全面的安全解决方案... 目录简介认识Spring Security“认证”(Authentication)“授权” (Auth

kotlin中的行为组件及高级用法

《kotlin中的行为组件及高级用法》Jetpack中的四大行为组件:WorkManager、DataBinding、Coroutines和Lifecycle,分别解决了后台任务调度、数据驱动UI、异... 目录WorkManager工作原理最佳实践Data Binding工作原理进阶技巧Coroutine

Java使用Mail构建邮件功能的完整指南

《Java使用Mail构建邮件功能的完整指南》JavaMailAPI是一个功能强大的工具,它可以帮助开发者轻松实现邮件的发送与接收功能,本文将介绍如何使用JavaMail发送和接收邮件,希望对大家有所... 目录1、简述2、主要特点3、发送样例3.1 发送纯文本邮件3.2 发送 html 邮件3.3 发送带