DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛

2024-08-26 07:36

本文主要是介绍DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛

  • baseline分析
    • 构建YOLO数据集
    • 开始训练
  • 优化思路

话不多说直接开始

baseline分析

这里我们忽略数据、模型下载的单元格
导入数据处理的一些包

import os, sys
import cv2, glob, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

读取下载的数据,并查看一下json的格式。

train_anno = json.load(open('训练集(有标注第一批)/标注/45.json', encoding='utf-8'))
train_anno[0], len(train_anno)

用pandas读取数据查看数据格式

pd.read_json('训练集(有标注第一批)/标注/45.json')

读取视频,使用VideoCapture对数据进行切帧处理。

video_path = '训练集(有标注第一批)/视频/45.mp4'
cap = cv2.VideoCapture(video_path)
while True:# 读取下一帧ret, frame = cap.read()if not ret:breakbreak  

根据json的信息,展示一张画框的图片

bbox = [746, 494, 988, 786]pt1 = (bbox[0], bbox[1])
pt2 = (bbox[2], bbox[3])color = (0, 255, 0) 
thickness = 2  # 线条粗细cv2.rectangle(frame, pt1, pt2, color, thickness)frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
plt.imshow(frame)

截止到这里,上面其实都可以忽略,上面就是让大家看一下原始数据这个格式是什么样,大概该怎么处理这个数据。接下来开始构建YOLO所需的数据集。

构建YOLO数据集

yolo数据集的格式为一个data文件夹下包含三个内容,train; val; yolo.yaml,其中train和val不在介绍,yolo.yaml主要包含数据涉及到的标签信息。
我这里是吧数据放在/root/data文件夹下了,因为切帧的图片数据很多需要的空间后的云的系统盘空间不够。大家可以参考。

if not os.path.exists('/root/data/yolo-dataset/'):os.mkdir('/root/data/yolo-dataset/')
if not os.path.exists('/root/data/yolo-dataset/train'):os.mkdir('/root/data/yolo-dataset/train')
if not os.path.exists('/root/data/yolo-dataset/val'):os.mkdir('/root/data/yolo-dataset/val')dir_path = os.path.abspath('./') + '/'# 需要按照你的修改path
with open('/root/data/yolo-dataset/yolo.yaml', 'w', encoding='utf-8') as up:up.write(f'''
path: /root/data/yolo-dataset/
train: train/
val: val/names:0: 非机动车违停1: 机动车违停2: 垃圾桶满溢3: 违法经营
''')

对获取的文件路径进行排序,以确保标注文件和视频文件按照相同顺序匹配。

train_annos = glob.glob('训练集(有标注第一批)/标注/*.json')
train_videos = glob.glob('训练集(有标注第一批)/视频/*.mp4')
train_annos.sort(); train_videos.sort();category_labels = ["非机动车违停", "机动车违停", "垃圾桶满溢", "违法经营"]

我这里按照8:2划分训练集和验证集,一共应该是52组数据,划分之后是42:10
下面在代码中给出了逐行的注释,大家自行食用即可。

for anno_path, video_path in zip(train_annos[:42], train_videos[:42]):print(video_path)# 使用Pandas读取JSON格式的标注文件,返回一个DataFrame对象anno_df = pd.read_json(anno_path)   # 使用OpenCV打开视频文件,准备逐帧读取cap = cv2.VideoCapture(video_path)frame_idx = 0 # 读取视频帧while True:ret, frame = cap.read()if not ret:break# 获取当前帧的高度和宽度img_height, img_width = frame.shape[:2]# 从标注文件中提取当前帧的标注信息frame_anno = anno_df[anno_df['frame_id'] == frame_idx]# 将当前帧保存为JPEG图像文件cv2.imwrite('/root/data/yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)# 检查当前帧有没有标注信息if len(frame_anno) != 0:# 创建并打开一个与当前帧图像同名的文本文件,准备写入YOLO格式的标签with open('/root/data/yolo-dataset/train/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):# 获取当前标注对象类别的索引category_idx = category_labels.index(category)# 获取框的坐标x_min, y_min, x_max, y_max = bbox# 计算标注框的中心点横纵坐标,并归一化到 [0, 1] 之间x_center = (x_min + x_max) / 2 / img_widthy_center = (y_min + y_max) / 2 / img_height# 计算框的宽和高,并归一化width = (x_max - x_min) / img_widthheight = (y_max - y_min) / img_heightif x_center > 1:print(bbox)# 将YOLO格式的标注信息写入标签文件up.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')# 处理下一帧frame_idx += 1

构建验证集,这部分代码直接看上一个即可一样的基本都是

for anno_path, video_path in zip(train_annos[-10:], train_videos[-10:]):print(video_path)anno_df = pd.read_json(anno_path)cap = cv2.VideoCapture(video_path)frame_idx = 0 while True:ret, frame = cap.read()if not ret:breakimg_height, img_width = frame.shape[:2]frame_anno = anno_df[anno_df['frame_id'] == frame_idx]cv2.imwrite('/root/data/yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.jpg', frame)if len(frame_anno) != 0:with open('/root/data/yolo-dataset/val/' + anno_path.split('/')[-1][:-5] + '_' + str(frame_idx) + '.txt', 'w') as up:for category, bbox in zip(frame_anno['category'].values, frame_anno['bbox'].values):category_idx = category_labels.index(category)x_min, y_min, x_max, y_max = bboxx_center = (x_min + x_max) / 2 / img_widthy_center = (y_min + y_max) / 2 / img_heightwidth = (x_max - x_min) / img_widthheight = (y_max - y_min) / img_heightup.write(f'{category_idx} {x_center} {y_center} {width} {height}\n')frame_idx += 1

开始训练

baseline使用的是yolov8n进行训练,在这里epoch代表训练的轮数,imgsz代表输入模型图像大小,batch代表一次梯度更新使用多少张图片

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"import warnings
warnings.filterwarnings('ignore')from ultralytics import YOLO
model = YOLO("yolov8n.pt")
results = model.train(data="/root/data/yolo-dataset/yolo.yaml", epochs=15, imgsz=1080, batch=16)

创建结果目录

category_labels = ["非机动车违停", "机动车违停", "垃圾桶满溢", "违法经营"]if not os.path.exists('result/'):os.mkdir('result')

对测试集视频文件的处理,通过预训练的YOLO模型对每个视频的每一帧进行检测,并将检测结果保存为JSON格式的文件。

from ultralytics import YOLO
# 使用训练好的模型进行预测
model = YOLO("runs/detect/train/weights/best.pt")
import globfor path in glob.glob('测试集/*.mp4'):# 保存结果生成的json文件submit_json = []# 对视频文件进行推理,conf=0.05设置了最低置信度阈值results = model(path, conf=0.05, imgsz=1080,  verbose=False)for idx, result in enumerate(results):boxes = result.boxes  # Boxes object for bounding box outputsmasks = result.masks  # Masks object for segmentation masks outputskeypoints = result.keypoints  # Keypoints object for pose outputsprobs = result.probs  # Probs object for classification outputsobb = result.obb  # Oriented boxes object for OBB outputsif len(boxes.cls) == 0:continue# 获取检测框的坐标、类别、置信度xywh = boxes.xyxy.data.cpu().numpy().round()cls = boxes.cls.data.cpu().numpy().round()conf = boxes.conf.data.cpu().numpy()# 写入submitfor i, (ci, xy, confi) in enumerate(zip(cls, xywh, conf)):submit_json.append({'frame_id': idx,'event_id': i+1,'category': category_labels[int(ci)],'bbox': list([int(x) for x in xy]),"confidence": float(confi)})# 保存json文件with open('./result/' + path.split('/')[-1][:-4] + '.json', 'w', encoding='utf-8') as up:json.dump(submit_json, up, indent=4, ensure_ascii=False)

优化思路

这篇关于DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1107941

相关文章

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

这15个Vue指令,让你的项目开发爽到爆

1. V-Hotkey 仓库地址: github.com/Dafrok/v-ho… Demo: 戳这里 https://dafrok.github.io/v-hotkey 安装: npm install --save v-hotkey 这个指令可以给组件绑定一个或多个快捷键。你想要通过按下 Escape 键后隐藏某个组件,按住 Control 和回车键再显示它吗?小菜一碟: <template

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

中文分词jieba库的使用与实景应用(一)

知识星球:https://articles.zsxq.com/id_fxvgc803qmr2.html 目录 一.定义: 精确模式(默认模式): 全模式: 搜索引擎模式: paddle 模式(基于深度学习的分词模式): 二 自定义词典 三.文本解析   调整词出现的频率 四. 关键词提取 A. 基于TF-IDF算法的关键词提取 B. 基于TextRank算法的关键词提取

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

Hadoop企业开发案例调优场景

需求 (1)需求:从1G数据中,统计每个单词出现次数。服务器3台,每台配置4G内存,4核CPU,4线程。 (2)需求分析: 1G / 128m = 8个MapTask;1个ReduceTask;1个mrAppMaster 平均每个节点运行10个 / 3台 ≈ 3个任务(4    3    3) HDFS参数调优 (1)修改:hadoop-env.sh export HDFS_NAMENOD

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na