利用pytorch两层线性网络对titanic数据集进行分类(kaggle)

2024-05-04 09:04

本文主要是介绍利用pytorch两层线性网络对titanic数据集进行分类(kaggle),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用pytorch两层线性网络对titanic数据集进行分类

最近在看pytorch的入门课程,做了一下在kaggle网站上的作业,用的是titanic数据集,因为想搭一下神经网络,所以数据加载部分简单的把训练集和测试集中有缺失值的列还有含有字符串的列去除了,加入了DataLoader模块,其实这个数据集很小,用不到,本人还没入门,小白一枚。

import torch 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets
from torchvision import transforms
import pandas as pdclass titanicDataset(Dataset):def __init__(self,filepath):xy=np.loadtxt(filepath,delimiter=',',skiprows=1,usecols=[1,2,7,8],dtype=np.float32)self.len=xy.shape[0]# print(self.len)self.y_data=torch.from_numpy(xy[:,[0]])self.x_data=torch.from_numpy(xy[:,1:])def __getitem__(self,index):#获取索引元素 return self.x_data[index],self.y_data[index]def __len__(self):return self.len
dataset=titanicDataset('./pytorch/dataset/titanic/train.csv')
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)# print(dataset.x_data,dataset.y_data)
test_loader=DataLoader(dataset=np.loadtxt('./pytorch/dataset/titanic/test.csv',delimiter=',',skiprows=1,usecols=[1,6,7],dtype=np.float32),batch_size=32,shuffle=False,num_workers=0)
print(next(iter(test_loader)))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()# self.linear1=torch.nn.Linear(4,3)self.linear2=torch.nn.Linear(3,2)self.linear3=torch.nn.Linear(2,1)self.sigmoid=torch.nn.Sigmoid()def forward(self,x):# x=self.sigmoid(self.linear1(x))x=self.sigmoid(self.linear2(x))x=self.sigmoid(self.linear3(x))return x
model=Model()
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9)
for epoch in range(10000):acc_num=0for i,data in enumerate(train_loader,0):#1.Prepare datainputs,labels=data# print(inputs.shape[0])#2.Forwardy_pred=model(inputs)loss=criterion(y_pred,labels)# print(epoch,i,loss.item())#3.Backwardoptimizer.zero_grad()loss.backward()#4.Updateoptimizer.step()y_pred_label=torch.where(y_pred>0.5,torch.tensor([1.0]),torch.tensor([0.0]))acc_num+=torch.eq(y_pred_label,labels).sum().item()# print(acc_num,len(dataset),len(train_loader.dataset))acc=acc_num/len(dataset)
print(acc)
# print(test_loader)
# print(test_loader.dataset.shape)
out = model(torch.tensor(test_loader.dataset))
y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))[:,0]
print(y_pred)
print(pd.Series(y_pred))
id=pd.read_csv('./pytorch/dataset/titanic/test.csv',usecols=['PassengerId']).iloc[:,0]
# print(type(id))pd.DataFrame({'PassengerId':id,'Survived':pd.Series(y_pred,dtype=int)}).to_csv('pred.csv',index=None)
a=pd.DataFrame([id,pd.Series(y_pred)])
print(a)
# print(y_pred[-10:])# for x in test_loader:
#     print(x.shape)
#     out = model(x)
#     y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))
# print(y_pred)

这篇关于利用pytorch两层线性网络对titanic数据集进行分类(kaggle)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/958834

相关文章

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

SpringBoot中使用 ThreadLocal 进行多线程上下文管理及注意事项小结

《SpringBoot中使用ThreadLocal进行多线程上下文管理及注意事项小结》本文详细介绍了ThreadLocal的原理、使用场景和示例代码,并在SpringBoot中使用ThreadLo... 目录前言技术积累1.什么是 ThreadLocal2. ThreadLocal 的原理2.1 线程隔离2

Python利用PIL进行图片压缩

《Python利用PIL进行图片压缩》有时在发送一些文件如PPT、Word时,由于文件中的图片太大,导致文件也太大,无法发送,所以本文为大家介绍了Python中图片压缩的方法,需要的可以参考下... 有时在发送一些文件如PPT、Word时,由于文件中的图片太大,导致文件也太大,无法发送,所有可以对文件中的图

Redis的数据过期策略和数据淘汰策略

《Redis的数据过期策略和数据淘汰策略》本文主要介绍了Redis的数据过期策略和数据淘汰策略,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录一、数据过期策略1、惰性删除2、定期删除二、数据淘汰策略1、数据淘汰策略概念2、8种数据淘汰策略

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

如何使用Spring boot的@Transactional进行事务管理

《如何使用Springboot的@Transactional进行事务管理》这篇文章介绍了SpringBoot中使用@Transactional注解进行声明式事务管理的详细信息,包括基本用法、核心配置... 目录一、前置条件二、基本用法1. 在方法上添加注解2. 在类上添加注解三、核心配置参数1. 传播行为(

Python给Excel写入数据的四种方法小结

《Python给Excel写入数据的四种方法小结》本文主要介绍了Python给Excel写入数据的四种方法小结,包含openpyxl库、xlsxwriter库、pandas库和win32com库,具有... 目录1. 使用 openpyxl 库2. 使用 xlsxwriter 库3. 使用 pandas 库

Java实战之自助进行多张图片合成拼接

《Java实战之自助进行多张图片合成拼接》在当今数字化时代,图像处理技术在各个领域都发挥着至关重要的作用,本文为大家详细介绍了如何使用Java实现多张图片合成拼接,需要的可以了解下... 目录前言一、图片合成需求描述二、图片合成设计与实现1、编程语言2、基础数据准备3、图片合成流程4、图片合成实现三、总结前

SpringBoot定制JSON响应数据的实现

《SpringBoot定制JSON响应数据的实现》本文主要介绍了SpringBoot定制JSON响应数据的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们... 目录前言一、如何使用@jsonView这个注解?二、应用场景三、实战案例注解方式编程方式总结 前言