本文主要是介绍【深度学习实战(30)】训练框架之使用tensorboard记录loss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一、 安装Tensorboard库
pip install tensorflow
pip install tensorboardx
二、LossHistory类实现过程
1. init
构造函数传入三个参数:log_dir(日志保存路径)、model(模型)、input_shape(模型输入尺寸)
def __init__(self, log_dir, model, input_shape):
实例化SummaryWriter
对象
# Create the SummaryWriter, passing self.log_dir as its target directory.
self.writer = SummaryWriter(self.log_dir)
tensorboard.SummaryWriter.add_graph
记录model
# Trace the model once with a random batch (2 samples, 3 channels) so that
# TensorBoard can display its graph. This is best-effort: a model that cannot
# be traced must not stop training, but the reason is reported instead of
# being silently discarded.
try:
    dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
    self.writer.add_graph(model, dummy_input)
except Exception as err:
    print("add_graph failed; the model graph was not recorded: %r" % (err,))
训练结束后查看保存的模型
- 记录
loss
# Append the new training and validation loss values to the in-memory lists.
self.losses.append(loss)
self.val_loss.append(val_loss)
txt
文档记录loss
# Append each loss as one line of text to its own file inside self.log_dir
# (training loss first, then validation loss).
for file_name, value in (("epoch_loss.txt", loss), ("epoch_val_loss.txt", val_loss)):
    with open(os.path.join(self.log_dir, file_name), 'a') as f:
        f.write(str(value) + "\n")
训练完后,查看epoch_loss.txt
和epoch_val_loss.txt
tensorboard.SummaryWriter.add_scalar
记录loss
# Log both losses as TensorBoard scalars under the tags 'loss' and 'val_loss',
# with epoch passed as the step argument.
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
训练结束后查看保存的loss
pyplot
绘制loss
曲线图
def loss_plot(self):
    """Redraw the loss curves and save them as ``epoch_loss.png`` in ``self.log_dir``.

    Plots ``self.losses`` and ``self.val_loss`` against their list index and,
    once enough points exist, overlays Savitzky-Golay smoothed versions of
    both. Smoothing is best-effort: if it fails, a warning is issued and the
    raw curves are still saved.
    """
    import warnings  # local import: this listing has no import block of its own

    iters = range(len(self.losses))

    plt.figure()
    plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
    plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

    # Filter window: 5 points for short histories, 15 once there are 25 or more.
    num = 5 if len(self.losses) < 25 else 15
    # savgol_filter rejects a window longer than the data, so smoothing is
    # skipped explicitly until the history is long enough.
    if len(self.losses) >= num:
        try:
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green',
                     linestyle='--', linewidth=2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513',
                     linestyle='--', linewidth=2, label='smooth val loss')
        except Exception as err:
            warnings.warn("Loss smoothing skipped: %r" % (err,))

    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc="upper right")

    plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

    # Release the figure so repeated calls do not accumulate open figures.
    plt.cla()
    plt.close("all")
查看loss
曲线图
三、LossHistory类完整代码
import os
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import warnings

import scipy.signal


class LossHistory():
    """Record per-epoch training and validation losses.

    Each call to :meth:`append_loss` stores the values in memory, appends them
    to ``epoch_loss.txt`` / ``epoch_val_loss.txt``, logs them as TensorBoard
    scalars and regenerates ``epoch_loss.png``. All files live in ``log_dir``.

    Attributes:
        log_dir: Directory that receives every output file.
        losses: Training losses in the order they were appended.
        val_loss: Validation losses in the order they were appended.
        writer: The ``SummaryWriter`` bound to ``log_dir``; callers close it
            with ``writer.close()`` when they are done.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and writer, and try to record the model graph.

        Args:
            log_dir: Output directory; created if missing, reused if present.
            model: Model handed to ``SummaryWriter.add_graph``.
            input_shape: Two-element sequence; its items become the last two
                dimensions of the random ``(2, 3, ., .)`` tracing input.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: re-using a directory (e.g. when resuming) must not raise.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

        # --------- 1. record the model graph with SummaryWriter.add_graph --------- #
        # Best-effort: a model that cannot be traced must not abort training,
        # but the reason is surfaced as a warning.
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            # use_strict_trace is an add_graph option (forwarded to the tracer),
            # not a torch.randn argument.
            self.writer.add_graph(model, dummy_input, use_strict_trace=False)
        except Exception as err:
            warnings.warn("add_graph failed; the model graph was not recorded: %r" % (err,))

    def append_loss(self, epoch, loss, val_loss):
        """Record one pair of losses everywhere and refresh the chart.

        Args:
            epoch: Step value passed to ``add_scalar``.
            loss: Training loss to record.
            val_loss: Validation loss to record.
        """
        # The directory may have been removed since construction.
        os.makedirs(self.log_dir, exist_ok=True)

        # --------- 2. keep the losses in memory --------- #
        self.losses.append(loss)
        self.val_loss.append(val_loss)

        # --------- 3. append the losses to the text logs --------- #
        self._append_line("epoch_loss.txt", loss)
        self._append_line("epoch_val_loss.txt", val_loss)

        # --------- 4. log the losses as TensorBoard scalars --------- #
        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)

        self.loss_plot()

    def _append_line(self, file_name, value):
        """Append ``str(value)`` plus a newline to ``file_name`` inside ``log_dir``."""
        with open(os.path.join(self.log_dir, file_name), 'a') as f:
            f.write(str(value) + "\n")

    @staticmethod
    def _smoothing_window(num_points):
        """Return the Savitzky-Golay window for ``num_points`` samples, or None.

        The window is 5 below 25 samples and 15 otherwise; ``None`` means the
        history is still shorter than the window, so smoothing must be skipped.
        """
        window = 5 if num_points < 25 else 15
        return window if num_points >= window else None

    # --------- 5. draw the loss curves with pyplot --------- #
    def loss_plot(self):
        """Redraw the loss curves and save them as ``epoch_loss.png`` in ``log_dir``.

        Raw curves are always drawn. Smoothed curves (polynomial order 3) are
        added once the history is at least as long as the filter window; a
        smoothing failure only produces a warning.
        """
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

        num = self._smoothing_window(len(self.losses))
        if num is not None:
            try:
                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green',
                         linestyle='--', linewidth=2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513',
                         linestyle='--', linewidth=2, label='smooth val loss')
            except Exception as err:
                warnings.warn("Loss smoothing skipped: %r" % (err,))

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        # Release the figure so one call per epoch does not accumulate figures.
        plt.cla()
        plt.close("all")
四、LossHistory类使用框架
# Import the class itself; a plain `import LossHistory` would bind the module,
# which cannot be called. (Assumes the class lives in LossHistory.py - adjust
# the module name to match your project.)
from LossHistory import LossHistory

# Build the loss recorder.
# NOTE(review): the tuple's first item becomes dimension 2 of the tracing
# input and the second item dimension 3 - confirm (input_W, input_H) is the
# intended order.
loss_history = LossHistory(log_dir, model, (input_W, input_H))

# After an epoch, hand the averaged training and validation losses to the recorder.
loss_history.append_loss(epoch + 1, train_loss / epoch_step, val_loss / epoch_step_val)

# Save a "best" checkpoint when this validation loss is the lowest recorded so
# far (or when it is the first one recorded).
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
    best_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss': val_loss}
    torch.save(best_ckpt, os.path.join(save_dir, name_best_weights))

# When training is over, close the TensorBoard SummaryWriter.
loss_history.writer.close()
这篇关于【深度学习实战(30)】训练框架之使用tensorboard记录loss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!