使用PyTorch AlexNet预训练模型对新数据集进行训练及预测

2024-08-28 21:36

本文主要是介绍使用PyTorch AlexNet预训练模型对新数据集进行训练及预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

      在 https://blog.csdn.net/fengbingchun/article/details/112709281 中介绍了AlexNet网络,这里使用PyTorch中提供的AlexNet预训练模型对新数据集进行训练,然后使用生成的模型进行预测。主要包括三部分:新数据集自动拆分、训练、预测

      1.新数据集自动拆分:这里使用从网络下载的西瓜、冬瓜图像作为新数据集,只有2类。西瓜存放在单独的目录下,名为watermelon;冬瓜存放在单独的目录下,名为wintermelon。图像总数为264张。

      以下为自动拆分数据集的实现代码:

import cv2
import os
import random
import shutil
import numpy as npclass SplitClassifyDataset:"""split the classification dataset"""def __init__(self, path_src, path_dst, ratios=(0.8, 0.1, 0.1)):"""path_src: source dataset pathpath_dst: the path to the split datasetratios: they are the ratios of train set, validation set, and test set, respectively """assert len(ratios) == 3, f"the length of ratios is not 3: {len(ratios)}"assert abs(ratios[0] + ratios[1] + ratios[2] - 1) < 1e-05, f"ratios sum must be 1: {ratios[0]}, {ratios[1]}, {ratios[2]}"self.path_src = path_srcself.path_dst = path_dstself.ratio_train = ratios[0]self.ratio_val = ratios[1]self.ratio_test = ratios[2]self.is_resize = Falseself.fill_value = Noneself.shape = Noneself.length_total = Noneself.classes = Noneself.mean = Noneself.std = Noneself.supported_img_formats = (".bmp", ".jpeg", ".jpg", ".png", ".webp")def resize(self, value=(114,114,114), shape=(256,256)):"""value: fill valueshape: the scaled shape"""self.is_resize = Trueself.fill_value = valueself.shape = shapedef _create_dir(self):self.classes = [name for name in os.listdir(self.path_src) if os.path.isdir(os.path.join(self.path_src, name))]for name in self.classes:directory = self.path_dst + "/train/" + nameif os.path.exists(directory):raise ValueError(f"{directory} directory already exists, delete it")os.makedirs(directory, exist_ok=True)directory = self.path_dst + "/val/" + nameif os.path.exists(directory):raise ValueError(f"{directory} directory already exists, delete it")os.makedirs(directory, exist_ok=True)if self.ratio_test != 0:directory = self.path_dst + "/test/" + nameif os.path.exists(directory):raise ValueError(f"{directory} directory already exists, delete it")os.makedirs(directory, exist_ok=True)def _get_images(self):image_names = {}self.length_total = 0for class_name in self.classes:imgs = []for root, dirs, files in os.walk(os.path.join(self.path_src, class_name)):for file in files:_, extension = os.path.splitext(file)if extension in self.supported_img_formats:imgs.append(file)else:print(f"Warning: {self.path_src+'/'+class_name+'/'+file} is an unsupported file")image_names[class_name] = imgsself.length_total += len(imgs)return image_namesdef _get_random_sequence(self, image_names):image_sequences = {}for name in self.classes:length = len(image_names[name])numbers = list(range(0, length))train_sequence = random.sample(numbers, int(length*self.ratio_train))# print("train_sequence:", train_sequence)val_sequence = [x for x in numbers if x not in train_sequence]if self.ratio_test != 0:val_sequence = random.sample(val_sequence, int(length*self.ratio_val))# print("val_sequence:", val_sequence)test_sequence = [x for x in numbers if x not in train_sequence and x not in val_sequence]# print("test_sequence:", test_sequence)else:test_sequence = []image_sequences[name] = [train_sequence, val_sequence, test_sequence]return image_sequencesdef _letterbox(self, img):shape = img.shape[:2] # current shape: [height, width, channel]new_shape = [self.shape[0], self.shape[1]]# scale ratio (new / old)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])# compute paddingnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh paddingdw /= 2 # divide padding into 2 sidesdh /= 2if shape[::-1] != new_unpad: # resizeimg = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.fill_value) # add borderreturn imgdef _copy_image(self):image_names = self._get_images()image_sequences = self._get_random_sequence(image_names) # train, val, testsum = 0for name in self.classes:for i in range(3):sum += len(image_sequences[name][i])assert self.length_total == sum, f"the length before and afeter the split must be equal: {self.length_total}:{sum}"for name in self.classes:dirname = ["train", "val", "test"]index = [0, 1, 2]if self.ratio_test == 0:index = [0, 1]for idx in index:for i in image_sequences[name][idx]:image_name = self.path_src + "/" + name + "/" + image_names[name][i]dst_dir_name =self.path_dst + "/" + dirname[idx] + "/" + name# print(image_name)if not self.is_resize: # only copyshutil.copy(image_name, dst_dir_name)else: # resize, scale the image proportionallyimg = cv2.imread(image_name) # BGRif img is None:raise FileNotFoundError(f"image not found: {image_name}")img = self._letterbox(img)cv2.imwrite(dst_dir_name+"/"+image_names[name][i], img)def _cal_mean_std(self):imgs = []std_reds = []std_greens = []std_blues = []for name in self.classes:dst_dir_name = self.path_dst + "/train/" + name + "/"for root, dirs, files in os.walk(dst_dir_name):for file in files:# print("file:", dst_dir_name+file)img = cv2.imread(dst_dir_name+file)if img is None:raise FileNotFoundError(f"image not found: {dst_dir_name}{file}")imgs.append(img)img_array = np.array(img)std_reds.append(np.std(img_array[:,:,0]))std_greens.append(np.std(img_array[:,:,1]))std_blues.append(np.std(img_array[:,:,2]))arr = np.array(imgs)# print("arr.shape:", arr.shape)self.mean = np.mean(arr, axis=(0, 1, 2)) / 255self.std = [np.mean(std_reds) / 255, np.mean(std_greens) / 255, np.mean(std_blues) / 255] # B,G,Rdef __call__(self):self._create_dir()self._copy_image()self._cal_mean_std()def get_mean_std(self):"""get the mean and variance"""return self.mean, self.stdif __name__ == "__main__":split = SplitClassifyDataset(path_src="../../data/database/classify/melon", path_dst="datasets/melon_new_classify")split.resize(shape=(256,256))split()mean, std = split.get_mean_std()print(f"mean: {mean}; std: {std}")print("====== execution completed ======")

      说明如下:

      (1).实现类名为SplitClassifyDataset,供外层调用;

      (2).接收的参数包括:源数据集目录(watermelon目录和wintermelon目录所在的目录);结果存放目录;拆分时训练集、验证集、测试集的比例;图像resize后的大小(图像不能变形,默认使用(114,114,114)填充);

      (3).图像会随机被拆分,即每次执行后结果图像会不同;

      (4).拆分后会计算训练集的均值和标准差

      以下为外层代码调用实现:

def split_dataset(src_dataset_path, dst_dataset_path, resize, ratios):split = SplitClassifyDataset(path_src=src_dataset_path, path_dst=dst_dataset_path, ratios=ast.literal_eval(ratios))# print("resize:", type(ast.literal_eval(resize))) # str to tuplesplit.resize(shape=ast.literal_eval(resize))split()mean, std = split.get_mean_std()print(f"mean: {mean}; std: {std}")

      执行后结果如下图所示:输出均值和标准差(后面训练和预测时都需要);新生成的目录组织结构(每个目录下存放一类)满足PyTorch的要求

      2.训练:

      (1).下载AlexNet预训练模型:仅第一次执行时会从网络下载

def load_pretraind_model():model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1) # the first execution will download model: alexnet-owt-7be5be79.pth, pos: C:\Users\xxxxxx/.cache\torch\hub\checkpoints\alexnet-owt-7be5be79.pth# print("model:", model)return model

      (2).加载拆分后的数据集:

def load_dataset(dataset_path, mean, std, labels_file):mean = ast.literal_eval(mean) # str to tuplestd = ast.literal_eval(std)# print(f"type: {type(mean)}, {type(std)}")train_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std),])train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)val_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std),])val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories int the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=0)write_labels(train_dataset.class_to_idx, labels_file)return len(train_dataset.class_to_idx), len(train_dataset), len(val_dataset), train_loader, val_loader

      (3).将对应的索引和标签写入文件:索引从0开始,依次加1;标签即拆分后的目录名;后面预测时会需要此文件

def write_labels(class_to_idx, labels_file):# print("class_to_idx:", class_to_idx)with open(labels_file, "w") as file:for key, val in class_to_idx.items():file.write("%d %s\n" % (int(val), key))

      (4).可视化训练过程中训练集和验证集的Loss和Accuracy:

def draw_graph(train_losses, train_accuracies, val_losses, val_accuracies):plt.subplot(1, 2, 1) # lossplt.title("Loss curve")plt.xlabel("Epoch Number")plt.ylabel("Loss")plt.plot(train_losses, color="blue")plt.plot(val_losses, color="red")plt.legend(["Train Loss", "Val Loss"])plt.subplot(1, 2, 2) # accuracyplt.title("Accuracy curve")plt.xlabel("Epoch Number")plt.ylabel("Accuracy")plt.plot(train_accuracies, color="blue")plt.plot(val_accuracies, color="red")plt.legend(["Train Accuracy", "Val Accuracy"])plt.show()

      某次的执行结果如下图所示:

      (5).主体代码:代码中加入了提前终止训练的判断条件;生成的最终模型名为best.pth

def train(dataset_path, epochs, mean, std, model_name, labels_file):classes_num, train_dataset_num, val_dataset_num, train_loader, val_loader = load_dataset(dataset_path, mean, std, labels_file)model = load_pretraind_model()in_features = model.classifier[6].in_features# print(f"in_features: {in_features}")model.classifier[6] = nn.Linear(in_features, classes_num) # modify the number of categoriesdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)optimizer = optim.Adam(model.parameters(), lr=0.0002) # set the optimizercriterion = nn.CrossEntropyLoss() # set the losstrain_losses = []train_accuracies = []val_losses = []val_accuracies = []highest_accuracy = 0.minimum_loss = 100.for epoch in range(epochs):# reference: https://learnopencv.com/image-classification-using-transfer-learning-in-pytorch/epoch_start = time.time()# print(colorama.Fore.CYAN + f"epoch: {epoch+1}/{epochs}")train_loss = 0.0 # losstrain_acc = 0.0 # accuracyval_loss = 0.0val_acc = 0.0model.train() # set to training modefor i, (inputs, labels) in enumerate(train_loader):inputs = inputs.to(device)labels = labels.to(device)# print("inputs.size(0):", inputs.size(0))optimizer.zero_grad() # clean existing gradientsoutputs = model(inputs) # forward passloss = criterion(outputs, labels) # compute lossloss.backward() # backpropagate the gradientsoptimizer.step() # update the parameterstrain_loss += loss.item() * inputs.size(0) # compute the total loss_, predictions = torch.max(outputs.data, 1) # compute the accuracycorrect_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to floattrain_acc += acc.item() * inputs.size(0) # compute the total accuracy# print(f"train batch number: {i}; train loss: {loss.item():.4f}; accuracy: {acc.item():.4f}")model.eval() # set to evaluation modewith torch.no_grad():for i, (inputs, labels) in enumerate(val_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs) # forward passloss = criterion(outputs, labels) # compute lossval_loss += loss.item() * inputs.size(0) # compute the total loss_, predictions = torch.max(outputs.data, 1) # compute validation accuracycorrect_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to floatval_acc += acc.item() * inputs.size(0) # compute the total accuracy# print(f"val batch number: {i}; validation loss: {loss.item():.4f}; accuracy: {acc.item():.4f}")avg_train_loss = train_loss / train_dataset_num # average training lossavg_train_acc = train_acc / train_dataset_num # average training accuracyavg_val_loss = val_loss / val_dataset_num # average validation lossavg_val_acc = val_acc / val_dataset_num # average validation accuracytrain_losses.append(avg_train_loss)train_accuracies.append(avg_train_acc)val_losses.append(avg_val_loss)val_accuracies.append(avg_val_acc)epoch_end = time.time()print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.4f}, accuracy:{avg_train_acc:.4f}; validation loss:{avg_val_loss:.4f}, accuracy:{avg_val_acc:.4f}; time:{epoch_end-epoch_start:.2f}s")if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:torch.save(model.state_dict(), model_name)highest_accuracy = avg_val_accminimum_loss = avg_val_lossif avg_val_loss < 0.00001 and avg_val_acc > 0.99999:print(colorama.Fore.YELLOW + "stop training early")torch.save(model.state_dict(), model_name)breakdraw_graph(train_losses, train_accuracies, val_losses, val_accuracies)

      运行结果如下图所示:

      3.预测:

      (1).解析上面2.3中生成的文本文件:

def parse_labels_file(labels_file):classes = {}with open(labels_file, "r") as file:for line in file:# print(f"line: {line}")idx_value = []for v in line.split(" "):idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the lineassert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"classes[int(idx_value[0])] = idx_value[1]# print(f"clases: {classes}; length: {len(classes)}")return classes

      (2).保存features:供后期使用,长度为9216

def get_images_list(images_path):image_names = []p = Path(images_path)for subpath in p.rglob("*"):if subpath.is_file():image_names.append(subpath)return image_namesdef save_features(model, input_batch, image_name):features = model.features(input_batch) # shape: torch.Size([1, 256, 6, 6])features = model.avgpool(features)features = torch.flatten(features, 1) # shape: torch.Size([1, 9216])if torch.cuda.is_available():features = features.squeeze().detach().cpu().numpy() # shape: (9216,)else:features = features.queeeze().detach().numpy()# print(f"features: {features}; shape: {features.shape}")dir_name = "tmp"if not os.path.exists(dir_name):os.makedirs(dir_name)file_name = Path(image_name)file_name = file_name.name# print(f"file name: {file_name}")features.tofile(dir_name+"/"+file_name+".bin")

      (3).主体代码:

def predict(model_name, labels_file, images_path, mean, std):classes = parse_labels_file(labels_file)assert len(classes) != 0, "the number of categories can't be 0"image_names = get_images_list(images_path)assert len(image_names) != 0, "no images found"mean = ast.literal_eval(mean) # str to tuplestd = ast.literal_eval(std)model = models.alexnet(weights=None)in_features = model.classifier[6].in_featuresmodel.classifier[6] = nn.Linear(in_features, len(classes)) # modify the number of categories# print("alexnet model:", model)model.load_state_dict(torch.load(model_name))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)print("image name\t\t\t\t\t\tclass\tprobability")model.eval()with torch.no_grad():for image_name in image_names:input_image = Image.open(image_name)preprocess = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])input_tensor = preprocess(input_image) # (c,h,w)input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)input_batch = input_batch.to(device)output = model(input_batch)# print(f"output.shape: {output.shape}")probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on itmax_value, max_index = torch.max(probabilities, dim=0)print(f"{image_name}\t\t\t\t\t\t{classes[max_index.item()]}\t{max_value.item():.4f}")save_features(model, input_batch, image_name)

      执行结果如下图所示:由结果可知,虽然数据集很少,训练次数也很少,但预测时能百分百预测准确

      支持的输入参数及入口函数如下图所示:

def parse_args():parser = argparse.ArgumentParser(description="AlexNet image classification")parser.add_argument("--task", required=True, type=str, choices=["split", "train", "predict"], help="specify what kind of task")parser.add_argument("--src_dataset_path", type=str, help="source dataset path")parser.add_argument("--dst_dataset_path", type=str, help="the path of the destination dataset after split")parser.add_argument("--resize", default=(256,256), help="the size to which images are resized when split the dataset")parser.add_argument("--ratios", default=(0.8,0.1,0.1), help="the ratio of split the data set(train set, validation set, test set), the test set can be 0, but their sum must be 1")parser.add_argument("--epochs", type=int, help="number of training")parser.add_argument("--mean", type=str, help="the mean of the training set of images")parser.add_argument("--std", type=str, help="the standard deviation of the training set of images")parser.add_argument("--model_name", type=str, help="the model generated during training or the model loaded during prediction")parser.add_argument("--labels_file", type=str, help="one category per line, the format is: index class_name")parser.add_argument("--predict_images_path", type=str, help="predict images path")args = parser.parse_args()return argsif __name__ == "__main__":colorama.init(autoreset=True)args = parse_args()if args.task == "split":# python test_alexnet.py --task split --src_dataset_path ../../data/database/classify/melon --dst_dataset_path datasets/melon_new_classify --resize (256,256) --ratios (0.7,0.2,0.1)split_dataset(args.src_dataset_path, args.dst_dataset_path, args.resize, args.ratios)elif args.task == "train":# python test_alexnet.py --task train --dst_dataset_path datasets/melon_new_classify --epochs 100 --mean (0.52817206,0.60931162,0.59818634) --std (0.2533697287956878,0.22790271847362834,0.2380239874816262) --model_name best.pth --labels_file classes.txttrain(args.dst_dataset_path, args.epochs, args.mean, args.std, args.model_name, args.labels_file)else: # predict# python test_alexnet.py --task predict --predict_images_path datasets/melon_new_classify/test --mean (0.52817206,0.60931162,0.59818634) --std (0.2533697287956878,0.22790271847362834,0.2380239874816262) --model_name best.pth --labels_file classes.txtpredict(args.model_name, args.labels_file, args.predict_images_path, args.mean, args.std)print(colorama.Fore.GREEN + "====== execution completed ======")

      GitHub:https://github.com/fengbingchun/NN_Test

这篇关于使用PyTorch AlexNet预训练模型对新数据集进行训练及预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1115945

相关文章

如何使用celery进行异步处理和定时任务(django)

《如何使用celery进行异步处理和定时任务(django)》文章介绍了Celery的基本概念、安装方法、如何使用Celery进行异步任务处理以及如何设置定时任务,通过Celery,可以在Web应用中... 目录一、celery的作用二、安装celery三、使用celery 异步执行任务四、使用celery

使用Python绘制蛇年春节祝福艺术图

《使用Python绘制蛇年春节祝福艺术图》:本文主要介绍如何使用Python的Matplotlib库绘制一幅富有创意的“蛇年有福”艺术图,这幅图结合了数字,蛇形,花朵等装饰,需要的可以参考下... 目录1. 绘图的基本概念2. 准备工作3. 实现代码解析3.1 设置绘图画布3.2 绘制数字“2025”3.3

详谈redis跟数据库的数据同步问题

《详谈redis跟数据库的数据同步问题》文章讨论了在Redis和数据库数据一致性问题上的解决方案,主要比较了先更新Redis缓存再更新数据库和先更新数据库再更新Redis缓存两种方案,文章指出,删除R... 目录一、Redis 数据库数据一致性的解决方案1.1、更新Redis缓存、删除Redis缓存的区别二

Jsoncpp的安装与使用方式

《Jsoncpp的安装与使用方式》JsonCpp是一个用于解析和生成JSON数据的C++库,它支持解析JSON文件或字符串到C++对象,以及将C++对象序列化回JSON格式,安装JsonCpp可以通过... 目录安装jsoncppJsoncpp的使用Value类构造函数检测保存的数据类型提取数据对json数

Redis事务与数据持久化方式

《Redis事务与数据持久化方式》该文档主要介绍了Redis事务和持久化机制,事务通过将多个命令打包执行,而持久化则通过快照(RDB)和追加式文件(AOF)两种方式将内存数据保存到磁盘,以防止数据丢失... 目录一、Redis 事务1.1 事务本质1.2 数据库事务与redis事务1.2.1 数据库事务1.

python使用watchdog实现文件资源监控

《python使用watchdog实现文件资源监控》watchdog支持跨平台文件资源监控,可以检测指定文件夹下文件及文件夹变动,下面我们来看看Python如何使用watchdog实现文件资源监控吧... python文件监控库watchdogs简介随着Python在各种应用领域中的广泛使用,其生态环境也

Python中构建终端应用界面利器Blessed模块的使用

《Python中构建终端应用界面利器Blessed模块的使用》Blessed库作为一个轻量级且功能强大的解决方案,开始在开发者中赢得口碑,今天,我们就一起来探索一下它是如何让终端UI开发变得轻松而高... 目录一、安装与配置:简单、快速、无障碍二、基本功能:从彩色文本到动态交互1. 显示基本内容2. 创建链

springboot整合 xxl-job及使用步骤

《springboot整合xxl-job及使用步骤》XXL-JOB是一个分布式任务调度平台,用于解决分布式系统中的任务调度和管理问题,文章详细介绍了XXL-JOB的架构,包括调度中心、执行器和Web... 目录一、xxl-job是什么二、使用步骤1. 下载并运行管理端代码2. 访问管理页面,确认是否启动成功

使用Nginx来共享文件的详细教程

《使用Nginx来共享文件的详细教程》有时我们想共享电脑上的某些文件,一个比较方便的做法是,开一个HTTP服务,指向文件所在的目录,这次我们用nginx来实现这个需求,本文将通过代码示例一步步教你使用... 在本教程中,我们将向您展示如何使用开源 Web 服务器 Nginx 设置文件共享服务器步骤 0 —

Java中switch-case结构的使用方法举例详解

《Java中switch-case结构的使用方法举例详解》:本文主要介绍Java中switch-case结构使用的相关资料,switch-case结构是Java中处理多个分支条件的一种有效方式,它... 目录前言一、switch-case结构的基本语法二、使用示例三、注意事项四、总结前言对于Java初学者