本文主要是介绍一口一口吃掉yolov8(2),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
前面介绍了训练的第一个部分,也是大部分人在网上找得到的文章,但是后面2个部分应该是网上没有的资料了,希望大家喜欢。
0.数据
我的数据是一些栈板,主要是检测栈板的空洞,识别出空洞的位置和偏转角度。原图如下
我的标注
我用labelme标注,然后转为yolo格式,转换代码如下。
# coding=utf-8
"""
Convert labelme rectangle annotations (*.json) into YOLO-format *.txt labels.

Author:Don
date:2022/8/3 11:49
desc:
"""
import glob
import json
import os
import sys

path = os.path.dirname(__file__)
sys.path.append(path)

# Input folder: holds the images and their labelme json files; the generated
# txt files are written into this same folder.
labelme_dir = r"E:\2022\work\shchaiduo\image"


def get_labelme_data(labelme_dir):
    """Read one labelme json file and return its boxes in pixel units.

    Args:
        labelme_dir: path of a single labelme ``*.json`` file (the parameter
            name is kept for backward compatibility).

    Returns:
        ``(img_h, img_w, out_data)`` where ``out_data`` is a list of
        ``[label, x_center, y_center, width, height]`` rows built from the
        first two points of every shape.
    """
    with open(labelme_dir) as f:
        j = json.load(f)
    img_h = j["imageHeight"]
    img_w = j["imageWidth"]
    out_data = []
    for shape in j["shapes"]:
        label = shape["label"]
        points = shape["points"]
        x, y = points[0][0], points[0][1]
        x2, y2 = points[1][0], points[1][1]
        # True division: floor division would snap the centre to a whole
        # pixel and shift the box by up to half a pixel.
        x_c = (x + x2) / 2
        y_c = (y + y2) / 2
        w = abs(x - x2)
        h = abs(y - y2)
        out_data.append([label, x_c, y_c, w, h])
    return img_h, img_w, out_data


def rename_Suffix(in_, mode=".txt"):
    """Return ``in_`` with its file extension replaced by ``mode``.

    Only the final extension is replaced, so dots in directory names
    (``v1.0/a.json``) or in the file stem are preserved.
    """
    return os.path.splitext(in_)[0] + mode


def make_yolo_data(in_dir):
    """Write a YOLO txt label file next to every labelme json in ``in_dir``.

    Each output row is ``label x_center y_center width height`` with the four
    box values divided by the image width/height.
    """
    # glob already returns paths that include in_dir; joining in_dir again
    # would duplicate the prefix for relative directories.
    for json_path in glob.glob(os.path.join(in_dir, '*.json')):
        txt_path = rename_Suffix(json_path)
        img_h, img_w, labelme_datas = get_labelme_data(json_path)
        with open(txt_path, 'w') as f:
            for label, x_c, y_c, w, h in labelme_datas:
                f.write("{} {} {} {} {}\n".format(
                    label, x_c / img_w, y_c / img_h, w / img_w, h / img_h))


if __name__ == '__main__':
    make_yolo_data(labelme_dir)
images是图片
labels是标签 txt格式
具体的是下图, 0是标签标识,因为只有一个class 所以我的数据里第一个都是0,后面是对应孔洞的xywh,但是要除以图片的长宽,具体的看上面的标签转换代码。 因为一个托盘只有2个孔洞,所以我的一个txt 只有2组数据。
test是图片
1.训练前数据准备
因为我的数据是实际现场采集的,所以很多数据增强的技术并不需要(个人理解)。在工业上,最重要的是安全而不是精度。意思就是:结果要么完全正确(100%),要么视为错误(0%);最好不存在误检,漏检是可以接受的。所以不建议让模型有过强的泛化能力,最好是遇到没见过的东西就直接报警处理,而不是给出一个大概的检测范围。因此我只保留了 YOLOv8 中的 Albumentations 增强接口,其他增强都去掉了。默认 batch_size=1。
from pathlib import Path
import glob
import os
from torch.utils.data import Dataset
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
from PIL import Image, ImageOps
import random
import albumentations as A
import numpy as np
import torch

# os.cpu_count() may return None, so fall back to 1 before subtracting.
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes


class Albumentations:
    """YOLOv8 Albumentations wrapper (optional, only used if package is installed).

    Applies a light pixel-level augmentation pipeline and converts the sample
    (image, classes, boxes) into torch tensors.
    """

    def __init__(self, p=1.0):
        """Build the pipeline; ``p`` is the probability of applying it to a sample."""
        self.p = p
        T = [
            A.Blur(p=0.01),
            A.MedianBlur(p=0.01),
            A.ToGray(p=0.01),
            A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(p=0.0),
            A.RandomGamma(p=0.0),
            A.ImageCompression(quality_lower=75, p=0.0),
        ]  # transforms
        self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

    def __call__(self, labels):
        """Augment one sample dict and convert its fields to tensors.

        Args:
            labels: dict with at least ``"img"`` (HWC array), ``"cls"`` and
                ``"bboxes"``.

        Returns:
            The same dict with ``"img"``, ``"cls"`` and ``"bboxes"`` replaced
            by tensors and a zero-filled ``"batch_idx"`` entry (one value per
            box) added.
        """
        im = labels["img"]
        cls = labels["cls"]
        bboxes = labels["bboxes"]
        if len(cls) and self.transform and random.random() < self.p:
            new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
            labels["img"] = self._format_img(new["image"])
            labels["cls"] = torch.tensor(new["class_labels"])
            labels["bboxes"] = torch.tensor(new["bboxes"])
        else:
            # Samples without labels (or skipped by the p draw) must still be
            # converted; otherwise collate_fn's torch.stack / torch.cat would
            # receive raw numpy arrays and fail.
            labels["img"] = self._format_img(im)
            labels["cls"] = torch.as_tensor(cls)
            labels["bboxes"] = torch.as_tensor(bboxes)
        labels["batch_idx"] = torch.zeros(labels["cls"].shape[0])
        return labels

    def _format_img(self, img):
        """Convert an HWC image array to a float32 CHW tensor with reversed channel order."""
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]).astype(np.float32)
        img = torch.from_numpy(img)
        return img


# Reading the dataset and caching its labels
def verify_image_label(args):
    """Verify one (image, label) pair and load its label rows.

    Args:
        args: ``(im_file, lb_file)`` tuple of paths.

    Returns:
        ``(im_file, lb, shape)`` where ``lb`` is a float32 array with up to
        five columns and ``shape`` is ``(height, width)``; or
        ``[None, None, None]`` when the image or label cannot be read.
    """
    im_file, lb_file = args
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = im.size  # image size (w, h)
        shape = (shape[1], shape[0])  # hw
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG: re-save it
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
        # verify labels
        if os.path.isfile(lb_file):
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
            lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
            else:
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            lb = np.zeros((0, 5), dtype=np.float32)
        lb = lb[:, :5]
        return im_file, lb, shape
    except Exception:
        # Deliberate best effort: any unreadable image/label is reported as
        # "skip this sample" instead of aborting the whole scan.
        return [None, None, None]


class YOLODataset(Dataset):
    """Detection dataset that reads images and YOLO txt labels from disk.

    Label files are located by replacing the last ``/images/`` path component
    of each image path with ``/labels/``; parsed labels are cached in a
    ``*.cache`` file next to the labels directory.
    """

    def __init__(self, img_path, imgsz=640, augment=True):
        """Scan ``img_path`` for images, load (or build) the label cache and set up transforms."""
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.im_files = self.get_img_files(self.img_path)  # read images
        self.labels = self.get_labels()  # read labels
        self.ni = len(self.labels)
        # transforms
        self.transforms = Albumentations(p=1.0)

    def get_img_files(self, img_path):
        """Read image files.

        Args:
            img_path: a directory, a text file listing image paths, or a list
                of such entries.

        Returns:
            Sorted list of image paths whose extension is in ``IMG_FORMATS``.

        Raises:
            FileNotFoundError: if an entry does not exist or cannot be read.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():  # file
                    with open(p) as t:
                        lines = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    # local to global path (only the leading "./" is rewritten)
                    f += [x.replace("./", parent, 1) if x.startswith("./") else x for x in lines]
                else:
                    raise FileNotFoundError(f"{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
        except Exception as e:
            raise FileNotFoundError(f"Error loading data from {img_path}") from e
        return im_files

    def img2label_paths(self, img_paths):
        """Define label paths as a function of image paths."""
        sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
        return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]

    @staticmethod
    def _files_hash(paths):
        """Return a fingerprint of ``paths`` (names plus total size of existing files)."""
        import hashlib  # local import: keeps the listing's import section unchanged

        size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))
        digest = hashlib.sha256(str(size).encode())
        digest.update("".join(paths).encode())
        return digest.hexdigest()

    def get_labels(self):
        """Return the per-image label dicts, loading them from the cache when it is valid.

        Also replaces ``self.im_files`` with the images that actually have a
        label entry, so dataset indices stay aligned when unreadable images
        were dropped during the scan.

        Raises:
            FileNotFoundError: if no image files were found.
        """
        import pickle  # local import, only needed for the exception type below

        if not self.im_files:
            raise FileNotFoundError(f"No images found in {self.img_path}")
        self.label_files = self.img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            # NOTE(security): allow_pickle=True unpickles the cache file; only
            # load caches from a trusted dataset directory.
            cache = np.load(str(cache_path), allow_pickle=True).item()  # load dict
            if cache.get("hash") != self._files_hash(self.label_files + self.im_files):
                raise ValueError("label cache is stale")
        except (OSError, EOFError, ValueError, AttributeError, pickle.UnpicklingError):
            cache = self.cache_labels(cache_path)  # run cache ops
        labels = cache["labels"]
        self.im_files = [lb["im_file"] for lb in labels]  # keep indices aligned with labels
        return labels

    def cache_labels(self, path=Path("./labels.cache")):
        """Cache dataset labels, check images and read shapes.

        Args:
            path: destination of the cache file.

        Returns:
            The cache dict with keys ``"labels"`` and ``"hash"``.
        """
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        x = {"labels": []}
        desc = f"Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files))  # im_file, lb, shape
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape in pbar:
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=None,
                            keypoints=None,
                            normalized=True,
                            bbox_format="xywh"))
            pbar.close()
        x["hash"] = self._files_hash(self.label_files + self.im_files)
        # Save through a file handle: np.save(<str path>) would append ".npy",
        # so the "*.cache" file would never be found on the next run.
        with open(path, "wb") as f:
            np.save(f, x)  # save cache for next time
        return x
2. 训练中取数据
取数据需要实现 __len__ 和 __getitem__ 两个方法,因为使用的是 torch 的 Dataset。由于要给 batch 中每张图片的目标重新编号(batch_idx),所以重写了 collate_fn 函数。
# NOTE: the definitions below are methods of YOLODataset; place them inside the class body.
def __len__(self):
    """Return the number of labelled samples."""
    return len(self.labels)


def __getitem__(self, index):
    """Load sample ``index`` and run it through ``self.transforms``."""
    return self.transforms(self.get_label_info(index))


def get_label_info(self, index):
    """Return a shallow copy of the label dict for ``index`` with the image and its shapes added."""
    label = self.labels[index].copy()
    label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
    return label


def load_image(self, i):
    """Load 1 image from dataset index ``i``.

    Returns:
        ``(im, (h0, w0), im.shape[:2])``: the image, its original height/width
        and its height/width after resizing.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    import cv2  # local import: cv2 is not imported anywhere in the listing

    f = self.im_files[i]
    im = cv2.imread(f)  # BGR
    if im is None:
        raise FileNotFoundError(f"Image Not Found {f}")
    h0, w0 = im.shape[:2]  # orig hw
    r = self.imgsz / max(h0, w0)  # ratio
    if r != 1:  # if sizes are not equal
        interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
        # NOTE(review): the target size is fixed at 640x512 (w, h) and does not
        # follow imgsz or the image aspect ratio; confirm this matches the data.
        im = cv2.resize(im, (640, 512), interpolation=interp)
    return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized


@staticmethod
def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    ``"img"`` entries are stacked, ``"bboxes"`` / ``"cls"`` are concatenated,
    and ``"batch_idx"`` becomes one tensor in which the boxes of the i-th
    sample carry the value ``i``. Every other key is returned as a tuple of
    the per-sample values. The input samples are not modified.
    """
    new_batch = {}
    for key in batch[0]:
        value = tuple(sample[key] for sample in batch)
        if key == "img":
            value = torch.stack(value, 0)
        elif key in ("bboxes", "cls"):
            value = torch.cat(value, 0)
        new_batch[key] = value
    # add target image index for build_targets(); `idx + i` creates a new
    # tensor, so the per-sample tensors are left untouched
    new_batch["batch_idx"] = torch.cat(
        [idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
    return new_batch
3.整合数据
from torch.utils.data import DataLoader  # used below but missing from the listing's imports


def seed_worker(worker_id):
    """Seed numpy and random inside a DataLoader worker.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
img_path = "../datasets/kongdong/images"
dataset = YOLODataset(img_path=img_path, imgsz=640, augment=True)
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
train_loader = DataLoader(dataset=dataset,
                          batch_size=1,
                          shuffle=True,
                          pin_memory=PIN_MEMORY,
                          collate_fn=getattr(dataset, "collate_fn", None),
                          worker_init_fn=seed_worker,
                          generator=generator)
# total must be the number of batches; a fixed total=1 makes the bar wrong
# for any dataset with more than one batch.
pbar = tqdm(enumerate(train_loader), total=len(train_loader), bar_format=TQDM_BAR_FORMAT)
for i, batch in pbar:
我们for 循环取数据集 那么batch里面有什么呢。我们看一下
现在我们检测一下数据做了变换后是否正确
# Check that the images and boxes fed to the network are correct
def check_data(batch):
    """Draw the batch's boxes on its image and show it in a window.

    Expects a batch of one image: ``batch["img"]`` shaped (1, C, H, W) and
    ``batch["bboxes"]`` holding normalised xywh rows. Blocks until a key is
    pressed. The batch itself is not modified.
    """
    import cv2  # local import: cv2 is not imported anywhere in the listing

    img = batch["img"]
    # Work on a copy: scaling batch["bboxes"] in place would corrupt the
    # normalised targets that are used after this check.
    boxes = batch["bboxes"].clone().reshape(-1, 4)  # xywh
    height, width = img.shape[-2:]
    boxes[:, 0] *= width
    boxes[:, 1] *= height
    boxes[:, 2] *= width
    boxes[:, 3] *= height
    input_tensor = img.squeeze(0)
    # Cast to uint8 and go from CHW to HWC so cv2 can use it (no value
    # rescaling is done here).
    input_tensor = input_tensor.permute(1, 2, 0).type(torch.uint8).numpy()
    # RGB to BGR
    input_tensor = cv2.cvtColor(input_tensor, cv2.COLOR_RGB2BGR)
    for box in boxes.int():  # xywh
        # thickness -1 draws filled rectangles
        cv2.rectangle(input_tensor,
                      (int(box[0] - box[2] / 2), int(box[1] - box[3] / 2)),
                      (int(box[0] + box[2] / 2), int(box[1] + box[3] / 2)),
                      (255, 0, 255), -1)
    cv2.imshow('img', input_tensor)
    cv2.waitKey(0)


for i, batch in pbar:
    # Forward
    with torch.cuda.amp.autocast(False):
        check_data(batch)
        img = batch["img"]
        preds = model(img)
ok,正确的,
我们再看一下模型的输出是否正确
ok,和我们第一个文章上前向推理网络的输出大小一致。
这篇关于一口一口吃掉yolov8(2)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!