27、ResNet50处理STEW数据集,用于情感三分类+全备的代码

2023-12-22 15:44

本文主要是介绍27、ResNet50处理STEW数据集,用于情感三分类+全备的代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、数据介绍

IEEE-Datasets-STEW:SIMULTANEOUS TASK EEG WORKLOAD DATASET :

该数据集由48名受试者的原始EEG数据组成,他们参加了利用SIMKAP多任务测试进行的多任务工作负荷实验。受试者在休息时的大脑活动也在测试前被记录下来,也包括在其中。Emotiv EPOC设备,采样频率为128Hz,有14个通道,用于获取数据,每个案例都有2.5分钟的EEG记录。受试者还被要求在每个阶段后以1到9的评分标准对其感知的心理工作量进行评分,评分结果在单独的文件中提供。

说明:每个受试者的数据遵循命名惯例:subno_task.txt。例如,sub01_lo.txt将是受试者1在休息时的原始脑电数据,而sub23_hi.txt将是受试者23在多任务测试中的原始脑电数据。每个数据文件的行对应于记录中的样本,列对应于EEG设备的14个通道: AF3, F7, F3, FC5, T7, P7, O1, O2, P8, T8, FC6, F4, F8, AF4。

数据说明、下载地址:

STEW: Simultaneous Task EEG Workload Data Set | IEEE Journals & Magazine | IEEE Xplore

2、代码

本次使用ResNet50,去做此情感数据的分类工作,数据导入+模型训练+测试代码如下:

import torch
import torchvision.datasets
from torch.utils.data import Dataset        # 继承Dataset类
import os
from PIL import Image
import numpy as np
from torchvision import transforms# 预处理
data_transform = transforms.Compose([transforms.Resize((224,224)),           # 缩放图像transforms.ToTensor(),                  # 转为Tensotransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))       # 标准化
])path =  r'C:\STEW\test'for root,dirs,files in os.walk(path):print('root',root) #遍历到该目录地址print('dirs',dirs) #遍历到该目录下的子目录名 []print('files',files)  #遍历到该目录下的文件  []def read_txt_files(path):# 创建文件名列表file_names = []# 遍历给定目录及其子目录下的所有文件for root, dirs, files in os.walk(path):# 遍历所有文件for file in files:# 如果是 .txt 文件,则加入文件名列表if file.endswith('.txt'): # endswith () 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。file_names.append(os.path.join(root, file))# 返回文件名列表return file_namesclass DogCat(Dataset):      # 数据处理def __init__(self,root,transforms = None):                  # 初始化,指定路径,是否预处理等等#['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']imgs = os.listdir(root)self.imgs = [os.path.join(root,img) for img in imgs]    # 取出root下所有的文件self.transforms = data_transform                        # 图像预处理def __getitem__(self, index):       # 读取图片img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0 #然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0.data = Image.open(img_path)if self.transforms:     # 图像预处理data = self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)dataset = DogCat('./data/',transforms=True)for img,label in dataset:print('img:',img.size(),'label:',label)
'''
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 0
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
img: torch.Size([3, 224, 224]) label: 1
'''import os# 获取file_path路径下的所有TXT文本内容和文件名
def get_text_list(file_path):files = os.listdir(file_path)text_list = []for file in files:with open(os.path.join(file_path, file), "r", encoding="UTF-8") as f:text_list.append(f.read())return text_list, filesclass ImageFolderCustom(Dataset):# 2. Initialize with a targ_dir and transform (optional) parameterdef __init__(self, targ_dir: str, transform=None) -> None:# 3. Create class attributes# Get all image pathsself.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's# Setup transformsself.transform = transform# Create classes and class_to_idx attributesself.classes, self.class_to_idx = find_classes(targ_dir)# 4. Make function to load imagesdef load_image(self, index: int) -> Image.Image:"Opens an image via a path and returns it."image_path = self.paths[index]return Image.open(image_path) # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)def __len__(self) -> int:"Returns the total number of samples."return len(self.paths)# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:"Returns one sample of data, data and label (X, y)."img = self.load_image(index)class_name  = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpegclass_idx = self.class_to_idx[class_name]# Transform if necessaryif self.transform:return self.transform(img), class_idx # return data, label (X, y)else:return img, class_idx # return data, label (X, y)import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transformsos.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2"# cifar-10进行测验class Cutout(object):"""Randomly mask out one or more patches from an image.Args:n_holes (int): Number of patches to cut out of each image.length (int): The length (in pixels) of each square patch."""def __init__(self, n_holes, length):self.n_holes = n_holesself.length = lengthdef __call__(self, img):"""Args:img (Tensor): Tensor image of size (C, H, W).Returns:Tensor: Image with n_holes of dimension length x length cut out of it."""h = img.size(1)w = img.size(2)mask = np.ones((h, w), np.float32)for n in range(self.n_holes):y = np.random.randint(h)x = np.random.randint(w)y1 = np.clip(y - self.length // 2, 0, h)y2 = np.clip(y + self.length // 2, 0, h)x1 = np.clip(x - self.length // 2, 0, w)x2 = np.clip(x + self.length // 2, 0, w)mask[y1: y2, x1: x2] = 0.mask = torch.from_numpy(mask)mask = mask.expand_as(img)img = img * maskreturn imgdef load_data_cifar10(batch_size=128,num_workers=2):# 操作合集# Data augmentationtrain_transform_1 = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(degrees=(-80,80)),  # 随机角度翻转transforms.ToTensor(),transforms.Normalize((0.491339968,0.48215827,0.44653124), (0.24703233,0.24348505,0.26158768)  # 两者分别为(mean,std)),Cutout(1, 16),  # 务必放在ToTensor的后面])train_transform_2 = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])test_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.491339968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)  # 两者分别为(mean,std))])# 训练集1trainset1 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_1,)# 训练集2trainset2 = tv.datasets.CIFAR10(root='data',train=True,download=False,transform=train_transform_2,)# 测试集testset = tv.datasets.CIFAR10(root='data',train=False,download=False,transform=test_transform,)# 训练数据加载器1trainloader1 = torch.utils.data.DataLoader(trainset1,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 训练数据加载器2trainloader2 = torch.utils.data.DataLoader(trainset2,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))# 测试数据加载器testloader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=(torch.cuda.is_available()))return trainloader1,trainloader2,testloaderdef main():start = time.time()batch_size = 128cifar_train1,cifar_train2,cifar_test = load_data_cifar10(batch_size=batch_size)model = resnet50().cuda()# model.load_state_dict(torch.load('_ResNet50.pth'))# 存在已保存的参数文件# model = nn.DataParallel(model,device_ids=[0,])  # 又套一层model = nn.DataParallel(model,device_ids=[0,1,2])loss = nn.CrossEntropyLoss().cuda()optimizer = optim.Adam(model.parameters(),lr=0.001)for epoch in range(50):model.train()  # 训练时务必写loss_=0.0num=0.0# train on trainloader1(data augmentation) and trainloader2for i,data in enumerate(cifar_train1,0):x, label = datax, label = x.cuda(),label.cuda()# xp = model(x) #outputl = loss(p,label) #lossoptimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num+=1for i, data in enumerate(cifar_train2, 0):x, label = datax, label = x.cuda(), label.cuda()# xp = model(x)l = loss(p, label)optimizer.zero_grad()l.backward()optimizer.step()loss_ += float(l.mean().item())num += 1model.eval()  # 评估时务必写print("loss:",float(loss_)/num)# test on trainloader2,testloaderwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_train2:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x)# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_1 = total_correct / total_num# Testwith torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:# [b, 3, 32, 32]# [b]x, label = x.cuda(), label.cuda()# [b, 10]logits = model(x) #output# [b]pred = logits.argmax(dim=1)# [b] vs [b] => scalar tensorcorrect = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)# print(correct)acc_2 = total_correct / total_numprint(epoch+1,'train acc',acc_1,'|','test acc:', acc_2)# 保存时只保存model.moduletorch.save(model.module.state_dict(),'resnet50.pth')print("The interval is :",time.time() - start)if __name__ == '__main__':main()

3、对你有帮助的话,给个关注吧~

这篇关于27、ResNet50处理STEW数据集,用于情感三分类+全备的代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/524525

相关文章

uniapp接入微信小程序原生代码配置方案(优化版)

uniapp项目需要把微信小程序原生语法的功能代码嵌套过来,无需把原生代码转换为uniapp,可以配置拷贝的方式集成过来 1、拷贝代码包到src目录 2、vue.config.js中配置原生代码包直接拷贝到编译目录中 3、pages.json中配置分包目录,原生入口组件的路径 4、manifest.json中配置分包,使用原生组件 5、需要把原生代码包里的页面修改成组件的方

公共筛选组件(二次封装antd)支持代码提示

如果项目是基于antd组件库为基础搭建,可使用此公共筛选组件 使用到的库 npm i antdnpm i lodash-esnpm i @types/lodash-es -D /components/CommonSearch index.tsx import React from 'react';import { Button, Card, Form } from 'antd'

17.用300行代码手写初体验Spring V1.0版本

1.1.课程目标 1、了解看源码最有效的方式,先猜测后验证,不要一开始就去调试代码。 2、浓缩就是精华,用 300行最简洁的代码 提炼Spring的基本设计思想。 3、掌握Spring框架的基本脉络。 1.2.内容定位 1、 具有1年以上的SpringMVC使用经验。 2、 希望深入了解Spring源码的人群,对 Spring有一个整体的宏观感受。 3、 全程手写实现SpringM

【服务器运维】MySQL数据存储至数据盘

查看磁盘及分区 [root@MySQL tmp]# fdisk -lDisk /dev/sda: 21.5 GB, 21474836480 bytes255 heads, 63 sectors/track, 2610 cylindersUnits = cylinders of 16065 * 512 = 8225280 bytesSector size (logical/physical)

代码随想录算法训练营:12/60

非科班学习算法day12 | LeetCode150:逆波兰表达式 ,Leetcode239: 滑动窗口最大值  目录 介绍 一、基础概念补充: 1.c++字符串转为数字 1. std::stoi, std::stol, std::stoll, std::stoul, std::stoull(最常用) 2. std::stringstream 3. std::atoi, std

记录AS混淆代码模板

开启混淆得先在build.gradle文件中把 minifyEnabled false改成true,以及shrinkResources true//去除无用的resource文件 这些是写在proguard-rules.pro文件内的 指定代码的压缩级别 -optimizationpasses 5 包明不混合大小写 -dontusemixedcaseclassnames 不去忽略非公共

雨量传感器的分类和选型建议

物理原理分类 机械降雨量计(雨量桶):最早使用的降雨量传感器,通过漏斗收集雨水并记录。主要用于长期降雨统计,故障率较低。电容式降雨量传感器:基于两个电极之间的电容变化来计算降雨量。当降雨时,水滴堵住电极空间,改变电容值,从而计算降雨量。超声波式降雨量传感器:利用超声波的反射来计算降雨量。适用于大降雨量的场合。激光雷达式降雨量传感器:利用激光技术测量雨滴的速度、大小和形状等参数,并计算降雨量。主

SQL Server中,查询数据库中有多少个表,以及数据库其余类型数据统计查询

sqlserver查询数据库中有多少个表 sql server 数表:select count(1) from sysobjects where xtype='U'数视图:select count(1) from sysobjects where xtype='V'数存储过程select count(1) from sysobjects where xtype='P' SE

时间服务器中,适用于国内的 NTP 服务器地址,可用于时间同步或 Android 加速 GPS 定位

NTP 是什么?   NTP 是网络时间协议(Network Time Protocol),它用来同步网络设备【如计算机、手机】的时间的协议。 NTP 实现什么目的?   目的很简单,就是为了提供准确时间。因为我们的手表、设备等,经常会时间跑着跑着就有误差,或快或慢的少几秒,时间长了甚至误差过分钟。 NTP 服务器列表 最常见、熟知的就是 www.pool.ntp.org/zo

麻了!一觉醒来,代码全挂了。。

作为⼀名程序员,相信大家平时都有代码托管的需求。 相信有不少同学或者团队都习惯把自己的代码托管到GitHub平台上。 但是GitHub大家知道,经常在访问速度这方面并不是很快,有时候因为网络问题甚至根本连网站都打不开了,所以导致使用体验并不友好。 经常一觉醒来,居然发现我竟然看不到我自己上传的代码了。。 那在国内,除了GitHub,另外还有一个比较常用的Gitee平台也可以用于