天池雪浪制造AI挑战赛(初赛)

2024-03-08 02:10

本文主要是介绍天池雪浪制造AI挑战赛(初赛),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

第一次参加比赛,记录一下,我是直接使用迁移学习进行分类 采用vgg16

排名不高仅供参考

import pandas as pd
import torch
import numpy as np
from torch.autograd import Variable
import torchvision
from torchvision import transforms, models
import matplotlib.pyplot as plt
import torch.nn.functional as F 
import os
from sklearn import metrics
import syssystem = sys.platform #判断系统的,两个电脑上 路径不一样
if system == 'win32':os.chdir('input')
mode = 'train'  # train用来训练, test生成csv提交结果
# mode = 'test'print('mode = ' + mode)#这一块是pytorch自带的的载入文件夹图片
transformer = transforms.Compose([transforms.Resize((224, 224)),# transforms.CenterCrop(200),# transforms.RandomVerticalFlip(),# transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])train_data = {x: torchvision.datasets.ImageFolder(x, transform=transformer)for x in ['train', 'val']}print(train_data['train'].class_to_idx)
train_loader = {}
train_loader['train'] = torch.utils.data.DataLoader(train_data['train'],batch_size=10,shuffle=True)
train_loader['val'] = torch.utils.data.DataLoader(train_data['val'],batch_size=10,shuffle=True)print('train num is ' + str(len(train_data['train'])))
print('val num is ' + str(len(train_data['val'])))if os.listdir('models'): #恢复模型print('restrore the model')model = torch.load('my_model.pkl')
else:print('use vgg16 model')# model = torch.load('vgg16.pkl') #因为网络不好, 我都是提前下下来保存再载入# model = torch.load('vgg_11_bn.pkl')# models.vgg16_bn(pretrained=True, batch_norm)model.classifier = torch.nn.Sequential(torch.nn.Linear(7*7*512, 2), #vgg提取特征不变  分类层改一下  if torch.cuda.is_available(): #cpu gpu转换model = model.cuda()
print(model)loss_func = torch.nn.CrossEntropyLoss()
lr = 1e-5optimizer = torch.optim.Adam(model.parameters(), lr=lr)## 建立这些列表基本都是用来画图的
epochs = 30 
plot_loss = []
best_auc = 0
auc_list = []
auc_list2 = []
train_acc_list = []
test_acc_list = []
# plt.ion()def valling(dir_name, model):"""得到网络输出   用来metrics0 1标签(用来算正确率)概率(算auc)label"""model.eval()print('valling in ' + str(dir_name))y_pre_all = np.array(())test_y_all = np.array(())all_pro = np.array(())for tep_idx, [test_x, test_y] in enumerate(train_loader[dir_name]):if tep_idx <= 10:test_x, test_y = next(iter(train_loader[dir_name]))if torch.cuda.is_available():test_x, test_y = (test_x.cuda()), (test_y.cuda())y_out_test = model(test_x)all_pro = np.append(all_pro, F.softmax(y_out_test, 0).cpu().data.numpy()[:, 1])# print(y_out_test)y_pre_test = torch.argmax(y_out_test, 1)y_pre_test = y_pre_test.cpu().data.numpy()test_y = test_y.cpu().data.numpy()# print(y_pre_all.shape)# print(y_pre_test.shape)y_pre_all = np.append(y_pre_all, y_pre_test)test_y_all = np.append(test_y_all, test_y)# print(y_pre_all.shape)return y_pre_all, test_y_all, all_prodef my_metrics(pre, label, pro):'''计算auc  acc'''# print('label shape is ' + str(label.shape))# print('pro shape is ' + str(pro.shape))auc = metrics.roc_auc_score(label, pro)bool_arr_test = (pre == label) test_acc = np.sum(bool_arr_test) / pre.sizereturn auc, test_accdef plot_list(list1, list2, dir_, title):'''画图  train 和test的acc  auc'''abs_dir = os.path.abspath(dir_)if not os.path.exists(os.path.dirname(abs_dir)):os.mkdir(os.path.dirname(abs_dir))print('creat dir{}'.format(abs_dir))plt.figure()plt.plot(list1, label='train')plt.plot(list2, label='test')plt.title(title)plt.legend(loc='best')plt.savefig(dir_)plt.close()if mode =='train':best_acc = 0plot_epoch_loss = []# print(model)for epoch in range(epochs):model.train()print('training')batch = 0epoch_loss = 0correct = 0# print(train_loader['train'])for data in train_loader['train']:batch += 1x, y = dataif torch.cuda.is_available():x, y = x.cuda(), y.cuda()x, y = Variable(x), Variable(y)y_out = model(x)optimizer.zero_grad()loss = loss_func(y_out, y)epoch_loss += loss# print(loss.data)# print(loss.data[0])loss.backward()optimizer.step()a_loss = loss.cpu().data.numpy()plot_loss.append(a_loss)plt.cla()plt.plot(plot_loss)print(a_loss)plt.text(0, 0.5, 'loss = %.3f' % a_loss, {'color': 'red', 'size': 15})plt.savefig('loss2.png')plt.close()plt.pause(0.5)y_pre_all, test_y_all, all_pro = valling('val', model)train_y_pre_all, train_test_y_all, train_all_pro = valling('train', model)auc, test_acc = my_metrics(y_pre_all, test_y_all, all_pro)train_auc, train_test_acc = my_metrics(train_y_pre_all, train_test_y_all, train_all_pro)train_acc_list.append(train_test_acc)test_acc_list.append(test_acc)saved_figs_dir = 'vgg11_full_32' plot_list(train_acc_list, test_acc_list, os.path.join('saved_figs', saved_figs_dir, 'acc.png'), 'acc_curve')auc_list.append(auc)auc_list2.append(train_auc)plot_list(auc_list2, auc_list, os.path.join('saved_figs', saved_figs_dir, 'auc.png'), 'auc_curve')best_acc = max(best_acc, test_acc) #保存最好的结果best_auc = max(best_auc, auc)print('test_acc = ' + str(test_acc * 100)[:4] + '%')print('train_acc = ' + str(train_test_acc * 100)[:4] + '%')epoch_loss = epoch_loss.cpu().data.numpy()print('This ' + str(epoch) + 'th epoch', 'epoch average loss = ' + str(epoch_loss/(batch)))plot_epoch_loss.append(epoch_loss / (batch))plt.figure()plt.plot(plot_epoch_loss)plt.title('epoch_loss')plt.savefig(os.path.join('saved_figs', saved_figs_dir, 'epoch_loss.png'))# plt.savefig('saved_figs/2/epoch_loss.png')print('lr = {}'.format(lr))if best_acc <= test_acc: #存正确率最高的模型# if best_auc <= auc:#存auc最高的print('score is better  store model')torch.save(model, 'models/my_model.pkl')else:print("not good don't save")print('-' * 40)    else:#用来生成提交结果test_data = torchvision.datasets.ImageFolder('test', transform=transformer)test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=10,shuffle=False)ret_df = pd.DataFrame(columns=['filename', 'probability'])filenames = []for i in test_data.imgs:filename = os.path.basename(i[0])filenames.append(filename)# print(filenames)ret_df['filename'] = filenamesfor i, [x, y] in enumerate(test_data_loader):if torch.cuda.is_available():x = x.cuda()x = Variable(x)pre_out = model(x)pro = F.softmax(pre_out).cpu().data.numpy()[:, 1]pro = np.clip(pro, 0.000001, 0.999999)print('The ' + str(i*10) + ' th ' + 'row')try:ret_df.iloc[10*i: 10*i+10, 1] = proexcept Exception:ret_df.loc[10*i:, 'probability'] = proret_df = ret_df.round(6)print((ret_df['probability'] <= 0).sum())print((ret_df['probability'] >= 1).sum())ret_df.to_csv('outputs/submission.csv', index=False, encoding='utf-8')

新人学习中

这篇关于天池雪浪制造AI挑战赛(初赛)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/785625

相关文章

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

Spring AI ectorStore的使用流程

《SpringAIectorStore的使用流程》SpringAI中的VectorStore是一种用于存储和检索高维向量数据的数据库或存储解决方案,它在AI应用中发挥着至关重要的作用,本文给大家介... 目录一、VectorStore的基本概念二、VectorStore的核心接口三、VectorStore的

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI集成DeepSeek实现流式输出的操作方法

《SpringAI集成DeepSeek实现流式输出的操作方法》本文介绍了如何在SpringBoot中使用Sse(Server-SentEvents)技术实现流式输出,后端使用SpringMVC中的S... 目录一、后端代码二、前端代码三、运行项目小天有话说题外话参考资料前面一篇文章我们实现了《Spring

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo

Spring AI集成DeepSeek的详细步骤

《SpringAI集成DeepSeek的详细步骤》DeepSeek作为一款卓越的国产AI模型,越来越多的公司考虑在自己的应用中集成,对于Java应用来说,我们可以借助SpringAI集成DeepSe... 目录DeepSeek 介绍Spring AI 是什么?1、环境准备2、构建项目2.1、pom依赖2.2

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

SpringBoot整合DeepSeek实现AI对话功能

《SpringBoot整合DeepSeek实现AI对话功能》本文介绍了如何在SpringBoot项目中整合DeepSeekAPI和本地私有化部署DeepSeekR1模型,通过SpringAI框架简化了... 目录Spring AI版本依赖整合DeepSeek API key整合本地化部署的DeepSeek