本文主要是介绍鸢尾花分类-pytorch实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
前言
本文用pytorch实现了鸢尾花分类,数据不多,只做代码展示用,后续有升级版本。
代码
'''
-*- coding: utf-8 -*-
@File : main.py
@Author: Shanmh
@Time : 2024/05/06 上午9:37
@Function:
'''
import torch
from sklearn import datasets
import torch.nn as nn#1.数据准备
dataset=datasets.load_iris()
print(dataset["data"][:10])
print(dataset["target"][:10])
i_data=torch.FloatTensor(dataset["data"])
i_target=torch.LongTensor(dataset["target"])#2.模型构建
class IrisModel(nn.Module):def __init__(self,input_n=4,hidden_n=20,output_n=3):super().__init__()self.line1=nn.Linear(input_n,hidden_n)self.line2=nn.Linear(hidden_n,output_n)self.relu=nn.ReLU()def forward(self,x):x=self.line1(x)x=self.relu(x)x=self.line2(x)return x#3.参数定义
epoch=500
lr=0.01model=IrisModel()
optimizer=torch.optim.SGD(model.parameters(),lr=lr) #定义优化器
loss_fun=torch.nn.CrossEntropyLoss() #多分类采用交叉熵损失函数for e in range(epoch):out=model(i_data)loss=loss_fun(out,i_target)optimizer.zero_grad() # 梯度清零loss.backward() # 前馈操作optimizer.step()# 5. 得出结果
out = model(i_data)
prediction = torch.max(out, 1)[1]
pred_y = prediction.data.numpy()
target_y = i_target.data.numpy()
result=pred_y==target_y
print(f"模型预测准确度,acc:{'{:.2f}'.format(len(result[result==True])/len(result))}%")
展望
1.还在考虑中怎么进行建模,建一个4维空间用来直接看出输入与输出的关系
2.有尝试过标签平滑,从结果上看不出什么区别,再想怎么可视化出来
3.怎么从结果倒推出可用的输入数据
这篇关于鸢尾花分类-pytorch实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!