基于PyTorch的视频分类实战

2024-03-19 07:44
文章标签 实战 pytorch 视频分类

本文主要是介绍基于PyTorch的视频分类实战,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、数据集下载

官方链接:https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads

百度网盘连接:

https://pan.baidu.com/s/1sSn--u_oLvTDjH-BgOAv_Q?pwd=xsri

提取码: xsri 

        官方链接有详细的数据集介绍,下载的是压缩包 ‘hmdb51_org.rar’,解压后里面是 51 个.rar 压缩包,每个压缩包名是一个类别,里面的是对应类别的视频片段(.avi 文件)。因为资源有限,这里只解压了 5 个类别的视频如图 1 所示:

图1 'hmdb5/org'

        这里新建了 ‘hmdb5’ 文件夹,并新建了 ‘org’ 子文件夹,然后把 ‘hmdb51_org’ 文件夹的 5 个子文件夹放到 ‘org’ 中。作为这次实践的源视频数据。

2、utils.py
        在这里先实现 utils.py,即取帧(get_frames)和存帧(store_frames)函数,取帧函数的功能为从视频中等间距抽取 n_frame 帧,并返回这些帧组成的列表。存帧函数的功能即为将帧列表按序存到 path 中。

import os
import cv2
import numpy as npdef get_frames(path, n_frames=1):""":param path: 视频文件路径:param n_frames: 读取的帧数:return: 读取的帧列表 frames"""frames = []# 实例化一个用于捕获视频流的对象, 若参数为整数则用于读取摄像头视频, 若参数为字符串则用于读取视频文件v_cap = cv2.VideoCapture(path)'''cv2.CAP_PROP_FRAME_COUNT 是 cv2.VideoCapture 类的一个属性标识符,用于获取视频流或视频文件中的总帧数cv2.VideoCapture 的 get 方法用于获取视频流或视频文件的属性(返回值均为实数):propId 是属性标识符,整数:cv2.CAP_PROP_FRAME_WIDTH:视频的帧宽度(以像素为单位)cv2.CAP_PROP_FRAME_HEIGHT:视频的帧高度(以像素为单位)cv2.CAP_PROP_FPS:视频的帧率(每秒的帧数)cv2.CAP_PROP_POS_FRAMES:当前读取帧的位置(基于 0 的索引)cv2.CAP_PROP_POS_AVI_RATIO:视频文件的相对位置(播放进度)cv2.CAP_PROP_FRAME_COUNT:视频文件中的总帧数'''v_len = int(v_cap.get(propId=cv2.CAP_PROP_FRAME_COUNT))'''在指定区间返回等距的数字数组:start: 区间起点stop: 区间终点num: 采样数量endpoint: 默认为 True,若为 False 则区间不包括 stop'''frame_list = np.linspace(start=0, stop=v_len - 1, num=n_frames + 1, dtype=np.int16)for fn in range(v_len):# 读取下一帧。它返回两个值:一个布尔值 success 表示是否成功读取帧和一个数组 frame 表示读取到的帧。success, frame = v_cap.read()if success is False:continueif fn in frame_list:frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)frames.append(frame)v_cap.release()return framesdef store_frames(frames, path):""":param frames: 待保存为 jpg 图片的帧列表:param path: 存储路径:return:"""for i, frame in enumerate(frames):frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)path2img = os.path.join(path, "frame" + str(i) + ".jpg")cv2.imwrite(path2img, frame)

3、数据抽帧,并划分训练集和测试集

        先在 ‘hmdb5’ 文件夹中新建子文件夹 ‘train’ 和 ‘test',再运行以下代码即可数据抽帧,并划分训练集和测试集。       

import os
from utils import get_frames, store_framespath = "hmdb5"
org_dir = "org"
org_path = os.path.join(path, org_dir)
categories_list = os.listdir(org_path)
# brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4# 输出每个类别的视频数量
for c in categories_list:print("category:", c)p = os.path.join(org_path, c)video_list = os.listdir(p)print("number of videos:", len(video_list))print("-" * 50)
"""
category: brush_hair
number of videos: 107
--------------------------------------------------
category: cartwheel
number of videos: 107
--------------------------------------------------
category: clap
number of videos: 130
--------------------------------------------------
category: catch
number of videos: 102
--------------------------------------------------
category: chew
number of videos: 109
--------------------------------------------------
"""extension = '.avi'
n_frames = 16
train_rate = 0.9for i, c in enumerate(categories_list):p = os.path.join(org_path, c)videos = [v for v in os.listdir(p) if v.endswith(extension)]train_size = int(len(videos) * train_rate)for j, name in enumerate(videos):video_path = os.path.join(p, name)frames = get_frames(video_path, n_frames=n_frames)path2store = os.path.join(path, "train")if j >= train_size:path2store = os.path.join(path, "test")path2store = os.path.join(path2store, str(i)+"_"+name[:-4])print(path2store)os.makedirs(path2store, exist_ok=True)store_frames(frames, path2store)

        第一段代码输出五个类别的视频数量,可以看到 brush_hair、cartwheel、clap、catch 和  chew 依次有 107、107、130、102、109 个视频。最后一段代码的功能是依次对每个类别的每个视频抽帧,并将抽帧结果存置指定路径,同时划分训练集和测试集。这里设置每个视频的抽帧数量 n_frame=16,按 9:1(498:57) 划分训练集和测试集。每个样本(视频文件夹)名都在原来的名字前拼接上 ‘类别编号_’,其中类别编号为:

brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

        这段代码的运行结果如图 2 所示(以测试集为例)即所有类别样本都在一个文件夹中,不再有类别目录,样本名字最前面的数字即为该样本的类别。

图2 测试集部分样本

4、train.py

4.1 导包

import os
import re
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torchvision.models import video
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

4.2 设置环境变量

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4.3 定义超参数

lr = 3e-5
gamma = 0.5
epochs = 20
step_size = 5
batch_size = 16
weight_decay = 1e-2

        这里定义初始学习率为 lr=3e-5,训练轮次为 epochs=20,batch_size=16,正则化系数为 weight_decay=1e-2。gamma 和 step_size 分为 torch.optim.lr_scheduler.ReduceLROnPlateau 类构造函数的入参 factor 和 patience。factor 是学习率降低的因子,新的学习率将是当前学习率乘以这个因子;patience 指观察验证指标在多少个 epoch 内没有改善后降低学习率。

4.4 定义图像变换函数

train_transform = transforms.Compose([transforms.Resize((112, 112)),transforms.RandomHorizontalFlip(p=0.5),# 用于对图像进行随机的仿射变换, degrees 为旋转角度, translate 为水平和垂直平移的最大绝对分数transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),transforms.ToTensor(),transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170])
])test_transform = transforms.Compose([transforms.Resize((112, 112)),transforms.ToTensor(),transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170]),
])

        这里定义了两个图像变换函数,即用于训练集的 train_transform 和用于测试集的 test_transform, train_transform 在训练前依次对图片进行 resize 操作,以 0.5 的概率水平镜像变换操作,随机仿射操作(随机沿 x,y 方向分别平移 (-0.1*w,0.1*w)、(-0.1*h,0.1*h)),转换为 tensor 操作和标准化操作。test_transform 相较于 train_transform 去掉了起数据增强作用的两个操作。

4.5 定义训练集和测试集路径

# 训练集(498):测试机(57)=9:1
train_dir = 'hmdb5/train'
test_dir = 'hmdb5/test'

4.6 定义数据集类

class HMDB5Dataset(Dataset):def __init__(self, directory, transform):self.dir = directoryself.transform = transformself.names = os.listdir(directory)def __len__(self):return len(self.names)def __getitem__(self, idx):path = os.path.join(self.dir, self.names[idx])frames = []for i in range(16):frame = Image.open(os.path.join(path, 'frame' + str(i) + '.jpg'))frames.append(self.transform(frame))frames = torch.stack(frames)# 返回 input 的转置版本, 即交换 input 的 dim0 和 dim1frames = torch.transpose(input=frames, dim0=0, dim1=1)# 编译正则表达式, ^ 表示匹配字符串的开始, + 表示一个或多个pattern = re.compile(r'^(\d+)_')match = re.search(pattern, self.names[idx])return frames, int(match.group(1))

        数据集类的构造函数定义了 3 个属性:dir(数据集路径)、transform(数据预处理方式)和names(样本名列表)。

        __getitem__ 函数根据 idx 按序取出一个样本的所有帧,并对所有帧执行了 transform 操作,最后返回的样本 frames 是 shape 为(channels,n_frames,h,w)的 tensor,该函数还利用 re 库从样本名中获取该样本的标签并返回。

4.7 定义模型

def init_model(mi):m = Noneif mi == 1:m = video.r3d_18(num_classes=5)  # epochs = 20, correct = 0.754return m.to(device)

        这里使用的模型为  torchvision.models.video.r3d_18[1],原文链接:https://arxiv.org/abs/1711.11248。实现可以参考 torch 源码。

4.8 计算评价指标

def correct_loss(data_loader, desc, test):results = []correct = 0.0test_loss = 0.0for img, tag in tqdm(data_loader, desc, total=len(data_loader)):img = img.to(device)tag = tag.to(device)pre = model(img)if test:test_loss += loss_fn(pre, tag)correct += torch.sum((pre.argmax(dim=1) == tag).float())results.append(correct / len(data_loader.dataset))if test:results.append(test_loss)return results

        correct_loss 函数用于计算 model 在 data_loader 上的 correct 和 loss(如果 test=True ,即data_loader 是测试集的数据加载器)。并将结果以列表的形式返回。

4.9 训练

if __name__ == '__main__':model = init_model(1)loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)train_ds = HMDB5Dataset(train_dir, train_transform)test_ds = HMDB5Dataset(test_dir, test_transform)train_dl = DataLoader(train_ds, batch_size, True, num_workers=2)test_dl = DataLoader(test_ds, batch_size, False, num_workers=2)'''在验证指标停止改善时降低学习率:mode(str): 值域为 {'min', 'max'}。指定优化器应该监视的指标是应该最小化还是最大化factor(float): 学习率降低的因子。新的学习率将是当前学习率乘以这个因子patience(int): 观察验证指标在多少个 epoch 内没有改善后降低学习率'''scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=gamma, patience=step_size, verbose=True)best_loss = float('inf')for epoch in range(epochs):s_loss = 0.0print('Epoch:', epoch + 1, '/', epochs)for x, y in tqdm(train_dl, total=len(train_dl)):x = x.to(device)y = y.to(device)pred = model(x)loss = loss_fn(pred, y)s_loss += lossloss.backward()optimizer.step()optimizer.zero_grad()model.eval()  # 将模型设置为评估模式with torch.no_grad():print("s_loss:%.3f" % s_loss)train_metrics = correct_loss(train_dl, 'compute train_metrics:', False)test_metrics = correct_loss(test_dl, 'compute test_metrics:', True)if test_metrics[1] < best_loss:best_loss = test_metrics[1]print("train_correct:%.3f,test_correct:%.3f" % (train_metrics[0], test_metrics[0]))model.train()scheduler.step(best_loss)

        这里使用交叉墒损失函数,AdamW 优化器,学习率使用 ReduceLROnPlateau scheduler,该 scheduler 监视的指标为 test loss。训练过程中得到的最高 test_correct=0.754。

5、项目目录结构

参考文献

[1] Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, and Manohar Paluri. A closer look at spatiotemporal convolutions for action recognition. In CVPR, pages 6450–6459, 2018. 

这篇关于基于PyTorch的视频分类实战的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/825266

相关文章

Golang操作DuckDB实战案例分享

《Golang操作DuckDB实战案例分享》DuckDB是一个嵌入式SQL数据库引擎,它与众所周知的SQLite非常相似,但它是为olap风格的工作负载设计的,DuckDB支持各种数据类型和SQL特性... 目录DuckDB的主要优点环境准备初始化表和数据查询单行或多行错误处理和事务完整代码最后总结Duck

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1

Golang使用minio替代文件系统的实战教程

《Golang使用minio替代文件系统的实战教程》本文讨论项目开发中直接文件系统的限制或不足,接着介绍Minio对象存储的优势,同时给出Golang的实际示例代码,包括初始化客户端、读取minio对... 目录文件系统 vs Minio文件系统不足:对象存储:miniogolang连接Minio配置Min

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

滚雪球学Java(87):Java事务处理:JDBC的ACID属性与实战技巧!真有两下子!

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE啦,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好习惯,别被干货淹没了哦~ 🏆本文收录于「滚雪球学Java」专栏,专业攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎大家关注&&收藏!持续更新中,up!up!up!! 环境说明:Windows 10

springboot实战学习(1)(开发模式与环境)

目录 一、实战学习的引言 (1)前后端的大致学习模块 (2)后端 (3)前端 二、开发模式 一、实战学习的引言 (1)前后端的大致学习模块 (2)后端 Validation:做参数校验Mybatis:做数据库的操作Redis:做缓存Junit:单元测试项目部署:springboot项目部署相关的知识 (3)前端 Vite:Vue项目的脚手架Router:路由Pina:状态管理Eleme