【深度学习实战(30)】训练框架之使用tensorboard记录loss

2024-05-02 06:36

本文主要是介绍【深度学习实战(30)】训练框架之使用tensorboard记录loss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、 安装Tensorboard库

pip install tensorflow 
pip install tensorboardx

二、LossHistory类实现过程

1. init构造函数
传入参数log保存路径,模型,模型输入尺寸

def __init__(self, log_dir, model, input_shape):

实例化SummaryWriter对象

self.writer     = SummaryWriter(self.log_dir)
  1. tensorboard.SummaryWriter.add_graph记录model
 try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:pass

训练结束后查看保存的模型
在这里插入图片描述
在这里插入图片描述

  1. 记录loss
self.losses.append(loss)
self.val_loss.append(val_loss)
  1. txt文档记录loss
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")

训练完后,查看epoch_loss.txtepoch_val_loss.txt
在这里插入图片描述
在这里插入图片描述

  1. tensorboard.SummaryWriter.add_scalar记录loss
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

训练结束后查看保存的loss
在这里插入图片描述
在这里插入图片描述

  1. pyplot绘制loss曲线图
def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")

查看loss曲线图
在这里插入图片描述

三、LossHistory类完整代码

import os
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signalclass LossHistory():def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:# --------- 1. tensorboard.SummaryWriter.add_graph记录model -------------#dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1],use_strict_trace=False)self.writer.add_graph(model, dummy_input)except:passdef append_loss(self, epoch, loss, val_loss):if not os.path.exists(self.log_dir):os.makedirs(self.log_dir)# --------- 2. 保存loss -------------#self.losses.append(loss)self.val_loss.append(val_loss)# --------- 3. txt记录loss -------------#with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")# --------- 4. tensorboard.SummaryWriter.add_scalar记录loss -------------#self.writer.add_scalar('loss', loss, epoch)self.writer.add_scalar('val_loss', val_loss, epoch)self.loss_plot()# --------- 5. pyplot绘制loss曲线图 -------------#def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")

四、LossHistory类使用框架

import LossHistory# 构造loss_history类
loss_history = LossHistory(log_dir, model, (input_W, input_H))# 训练一轮,将训练,验证损失传进loss_history类
loss_history.append_loss(epoch + 1, train_loss / epoch_step, val_loss / epoch_step_val)# 根据loss_history类中保存的loss来保存最佳模型
if len(loss_history_.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history_.val_loss):best_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss':val_loss}torch.save(best_ckpt, os.path.join(save_dir, name_best_weights))# 训练一轮结束后,关闭Tensorboard.SummryWriter
loss_history.writer.close()

这篇关于【深度学习实战(30)】训练框架之使用tensorboard记录loss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/953596

相关文章

Python在二进制文件中进行数据搜索的实战指南

《Python在二进制文件中进行数据搜索的实战指南》在二进制文件中搜索特定数据是编程中常见的任务,尤其在日志分析、程序调试和二进制数据处理中尤为重要,下面我们就来看看如何使用Python实现这一功能吧... 目录简介1. 二进制文件搜索概述2. python二进制模式文件读取(rb)2.1 二进制模式与文本

Django调用外部Python程序的完整项目实战

《Django调用外部Python程序的完整项目实战》Django是一个强大的PythonWeb框架,它的设计理念简洁优雅,:本文主要介绍Django调用外部Python程序的完整项目实战,文中通... 目录一、为什么 Django 需要调用外部 python 程序二、三种常见的调用方式方式 1:直接 im

C#中checked关键字的使用小结

《C#中checked关键字的使用小结》本文主要介绍了C#中checked关键字的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录✅ 为什么需要checked? 问题:整数溢出是“静默China编程”的(默认)checked的三种用

C#中预处理器指令的使用小结

《C#中预处理器指令的使用小结》本文主要介绍了C#中预处理器指令的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录 第 1 名:#if/#else/#elif/#endif✅用途:条件编译(绝对最常用!) 典型场景: 示例

C++ 右值引用(rvalue references)与移动语义(move semantics)深度解析

《C++右值引用(rvaluereferences)与移动语义(movesemantics)深度解析》文章主要介绍了C++右值引用和移动语义的设计动机、基本概念、实现方式以及在实际编程中的应用,... 目录一、右值引用(rvalue references)与移动语义(move semantics)设计动机1

SpringBoot整合 Quartz实现定时推送实战指南

《SpringBoot整合Quartz实现定时推送实战指南》文章介绍了SpringBoot中使用Quartz动态定时任务和任务持久化实现多条不确定结束时间并提前N分钟推送的方案,本文结合实例代码给大... 目录前言一、Quartz 是什么?1、核心定位:解决什么问题?2、Quartz 核心组件二、使用步骤1

Mysql中RelayLog中继日志的使用

《Mysql中RelayLog中继日志的使用》MySQLRelayLog中继日志是主从复制架构中的核心组件,负责将从主库获取的Binlog事件暂存并应用到从库,本文就来详细的介绍一下RelayLog中... 目录一、什么是 Relay Log(中继日志)二、Relay Log 的工作流程三、Relay Lo

使用Redis实现会话管理的示例代码

《使用Redis实现会话管理的示例代码》文章介绍了如何使用Redis实现会话管理,包括会话的创建、读取、更新和删除操作,通过设置会话超时时间并重置,可以确保会话在用户持续活动期间不会过期,此外,展示了... 目录1. 会话管理的基本概念2. 使用Redis实现会话管理2.1 引入依赖2.2 会话管理基本操作

Springboot请求和响应相关注解及使用场景分析

《Springboot请求和响应相关注解及使用场景分析》本文介绍了SpringBoot中用于处理HTTP请求和构建HTTP响应的常用注解,包括@RequestMapping、@RequestParam... 目录1. 请求处理注解@RequestMapping@GetMapping, @PostMappin

springboot3.x使用@NacosValue无法获取配置信息的解决过程

《springboot3.x使用@NacosValue无法获取配置信息的解决过程》在SpringBoot3.x中升级Nacos依赖后,使用@NacosValue无法动态获取配置,通过引入SpringC... 目录一、python问题描述二、解决方案总结一、问题描述springboot从2android.x