TensorBoard在pytorch训练过程中如何使用,及数据读取问题解决方法

本文主要是介绍TensorBoard在pytorch训练过程中如何使用,及数据读取问题解决方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

TensorBoard

    • 模块导入
    • 日志记录文件的创建
    • 训练中如何写入数据
    • 如何提取保存的数据调用TensorBoard面板
    • 可能会遇到的问题

模块导入

首先从torch中导入tensorboard的SummaryWriter日志记录模块

from torch.utils.tensorboard import SummaryWriter

然后导入要用到的os库,当然你们也要导入自己模型训练需要用到的库

import os

日志记录文件的创建

import oslog_dir = 'runs/EfficientNet_B3_experiment2'# 检查目录是否存在
if os.path.exists(log_dir):# 如果目录存在,获取目录下的所有文件和子目录列表files = os.listdir(log_dir)# 遍历目录下的文件和子目录for file in files:# 拼接文件的完整路径file_path = os.path.join(log_dir, file)# 判断是否为文件if os.path.isfile(file_path):# 如果是文件,删除该文件os.remove(file_path)elif os.path.isdir(file_path):# 如果是目录,递归地删除目录及其下的所有文件和子目录for root, dirs, files in os.walk(file_path, topdown=False):for name in files:os.remove(os.path.join(root, name))for name in dirs:os.rmdir(os.path.join(root, name))os.rmdir(file_path)# 创建新的SummaryWriter
writer = SummaryWriter(log_dir)

这个代码会自动创建并更新日志文件目录,请谨慎使用,记得改
log_dir = 'runs/EfficientNet_B3_experiment2'路径名字小心把之前保存好的数据删除了
之后模型训练的数据将会写入到log_dir这个路径文件中,在由TensorBoard张量板调用显示数据

训练中如何写入数据

for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0start_time = time.time()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 记录学习率current_lr = optimizer.param_groups[0]['lr']writer.add_scalar('Learning Rate', current_lr, epoch)# 记录梯度范数total_norm = 0for p in model.parameters():param_norm = p.grad.data.norm(2)total_norm += param_norm.item() ** 2total_norm = total_norm ** 0.5writer.add_scalar('Gradient Norm', total_norm, epoch)running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()train_loss = running_loss / len(train_loader)train_accuracy = 100 * correct / total# 记录训练损失和准确率writer.add_scalar('Training Loss', train_loss, epoch)writer.add_scalar('Training Accuracy', train_accuracy, epoch)# 记录模型参数的直方图for name, param in model.named_parameters():writer.add_histogram(name, param, epoch)# 记录网络结构(通常只需要记录一次)if epoch == 0:writer.add_graph(model, images.to(device))# 记录输入图片img_grid = torchvision.utils.make_grid(images)writer.add_image('train_images', img_grid, epoch)# 使用matplotlib记录渲染的图片fig, ax = plt.subplots()ax.plot(np.arange(len(labels)), labels.cpu().numpy(), 'b', label='True')ax.plot(np.arange(len(predicted)), predicted.cpu().numpy(), 'r', label='Predicted')ax.legend()writer.add_figure('predictions vs. actuals', fig, epoch)# 验证模型model.eval()val_loss = 0.0correct = 0total = 0all_preds = []all_labels = []with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())total += labels.size(0)correct += (predicted == labels).sum().item()val_loss /= len(test_loader)val_accuracy = 100 * correct / totalif val_accuracy > best_val_accuracy:# 当新的最佳验证准确率出现时,保存模型状态字典best_val_accuracy = val_accuracybest_model_state_dict = model.state_dict()# 记录验证损失和准确率writer.add_scalar('Validation Loss', val_loss, epoch)writer.add_scalar('Validation Accuracy', val_accuracy, epoch)# 记录多条曲线writer.add_scalars('Loss', {'train': train_loss, 'val': val_loss}, epoch)writer.add_scalars('Accuracy', {'train': train_accuracy, 'val': val_accuracy}, epoch)# 打印每个epoch的训练和验证结果print(f'Epoch [{epoch+1}/{num_epochs}], 'f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, 'f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%, 'f'Time: {time.time() - start_time:.2f}s')

以上代码分别记录了
在这里插入图片描述
在这里插入图片描述

如何提取保存的数据调用TensorBoard面板

在终端输入以下代码

tensorboard --logdir='修改为自己的log_dir路径'

在这里插入图片描述
然后点击 http://localhost:6006/就可以成功加载面板了
在这里插入图片描述

可能会遇到的问题

如果数据读取失败那么请检查数据路径是否正确
注意数据文件中不能有任何中文

这篇关于TensorBoard在pytorch训练过程中如何使用,及数据读取问题解决方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1026870

相关文章

使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)

《使用Java解析JSON数据并提取特定字段的实现步骤(以提取mailNo为例)》在现代软件开发中,处理JSON数据是一项非常常见的任务,无论是从API接口获取数据,还是将数据存储为JSON格式,解析... 目录1. 背景介绍1.1 jsON简介1.2 实际案例2. 准备工作2.1 环境搭建2.1.1 添加

MySQL中删除重复数据SQL的三种写法

《MySQL中删除重复数据SQL的三种写法》:本文主要介绍MySQL中删除重复数据SQL的三种写法,文中通过代码示例讲解的非常详细,对大家的学习或工作有一定的帮助,需要的朋友可以参考下... 目录方法一:使用 left join + 子查询删除重复数据(推荐)方法二:创建临时表(需分多步执行,逻辑清晰,但会

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

C#读取本地网络配置信息全攻略分享

《C#读取本地网络配置信息全攻略分享》在当今数字化时代,网络已深度融入我们生活与工作的方方面面,对于软件开发而言,掌握本地计算机的网络配置信息显得尤为关键,而在C#编程的世界里,我们又该如何巧妙地读取... 目录一、引言二、C# 读取本地网络配置信息的基础准备2.1 引入关键命名空间2.2 理解核心类与方法

Redis连接失败:客户端IP不在白名单中的问题分析与解决方案

《Redis连接失败:客户端IP不在白名单中的问题分析与解决方案》在现代分布式系统中,Redis作为一种高性能的内存数据库,被广泛应用于缓存、消息队列、会话存储等场景,然而,在实际使用过程中,我们可能... 目录一、问题背景二、错误分析1. 错误信息解读2. 根本原因三、解决方案1. 将客户端IP添加到Re

如何使用celery进行异步处理和定时任务(django)

《如何使用celery进行异步处理和定时任务(django)》文章介绍了Celery的基本概念、安装方法、如何使用Celery进行异步任务处理以及如何设置定时任务,通过Celery,可以在Web应用中... 目录一、celery的作用二、安装celery三、使用celery 异步执行任务四、使用celery

使用Python绘制蛇年春节祝福艺术图

《使用Python绘制蛇年春节祝福艺术图》:本文主要介绍如何使用Python的Matplotlib库绘制一幅富有创意的“蛇年有福”艺术图,这幅图结合了数字,蛇形,花朵等装饰,需要的可以参考下... 目录1. 绘图的基本概念2. 准备工作3. 实现代码解析3.1 设置绘图画布3.2 绘制数字“2025”3.3

详谈redis跟数据库的数据同步问题

《详谈redis跟数据库的数据同步问题》文章讨论了在Redis和数据库数据一致性问题上的解决方案,主要比较了先更新Redis缓存再更新数据库和先更新数据库再更新Redis缓存两种方案,文章指出,删除R... 目录一、Redis 数据库数据一致性的解决方案1.1、更新Redis缓存、删除Redis缓存的区别二

oracle数据库索引失效的问题及解决

《oracle数据库索引失效的问题及解决》本文总结了在Oracle数据库中索引失效的一些常见场景,包括使用isnull、isnotnull、!=、、、函数处理、like前置%查询以及范围索引和等值索引... 目录oracle数据库索引失效问题场景环境索引失效情况及验证结论一结论二结论三结论四结论五总结ora

JAVA中整型数组、字符串数组、整型数和字符串 的创建与转换的方法

《JAVA中整型数组、字符串数组、整型数和字符串的创建与转换的方法》本文介绍了Java中字符串、字符数组和整型数组的创建方法,以及它们之间的转换方法,还详细讲解了字符串中的一些常用方法,如index... 目录一、字符串、字符数组和整型数组的创建1、字符串的创建方法1.1 通过引用字符数组来创建字符串1.2