本文主要是介绍飞桨告诉我谁是最勤劳的小蜜蜂,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
谁是最勤劳的小蜜蜂
蜜蜂生来就是为了采蜜,它们也很辛勤,不是在采蜜就是已经采完蜜了,要不然就是在蜂巢里吐蜜,直到奉献自己的一生。据悉勤劳的蜜蜂翅膀上都带着花粉,那我们就来瞅瞅吧。
反正我肉眼看不清他们翅膀、小腿上的花粉,那么就交给机器来解决吧,机器学习嘛。
import os
import zipfile
import random
import json
import paddle
import sys
import numpy as np
from PIL import Image
from PIL import ImageEnhance
import paddle
import matplotlib.pyplot as plt
%matplotlib inline
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
一、数据准备
- (1)解压原始数据集
- (2)按照比例划分训练集与验证集
- (3)乱序,生成数据列表
- (4)构造训练数据集提供器和验证数据集提供器
# 已解压可注释
# !unzip -q data/data71008/花粉数据集archive.zip -d dataset
import paddle
import paddle.vision.transforms as T
import numpy as np
from PIL import Image


class BeeDataset(paddle.io.Dataset):
    """Two-class bee image dataset indexed by ``pollen_data.csv``.

    Each CSV row contributes a ``[file_name, label]`` pair taken from the
    second and third comma-separated columns. The first ``rate`` fraction of
    the rows forms the evaluation split; the remaining rows form the
    training split.
    """

    def __init__(self, mode='train', rate=0.2):
        """Load the CSV index and select the requested split.

        Args:
            mode: ``'train'`` selects the rows after the first ``rate``
                fraction; any other value selects the first ``rate`` fraction.
            rate: Fraction of all rows reserved for the non-training split.
        """
        super().__init__()
        self.all_data = []
        self.data = []
        with open('dataset/PollenDataset/pollen_data.csv', encoding='utf-8') as f:
            # Skip the header row; tolerate an empty file instead of raising
            # StopIteration.
            next(f, None)
            for line in f:
                info = line.strip().split(',')
                # Blank or truncated rows have no file name / label columns.
                if len(info) >= 3:
                    self.all_data.append([info[1].strip(), info[2].strip()])

        self.transforms = T.Compose([
            T.Resize((64, 64)),  # resize to (h, w) = (64, 64)
            T.ToTensor(),  # format conversion and scaling, HWC => CHW
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),  # per-channel normalisation
        ])

        split = int(len(self.all_data) * rate)
        if mode == 'train':
            self.data = self.all_data[split:]
        else:
            self.data = self.all_data[:split]

    def get_origin_data(self):
        """Return the raw ``[file_name, label]`` records of this split."""
        return self.data

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the sample at ``index``.

        The image is loaded from ``dataset/PollenDataset/images``, converted
        to RGB when needed and passed through ``self.transforms``; the label
        is returned as an ``int64`` NumPy array.
        """
        image_file, label = self.data[index]
        image_file = os.path.join('dataset/PollenDataset/images', image_file)
        # The context manager closes the underlying file once the pixels have
        # been consumed by the transforms.
        with Image.open(image_file) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = self.transforms(image)
        return image, np.array(label, dtype='int64')

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)
# Instance built with the default arguments (mode='train', rate=0.2).
bee = BeeDataset()

# With rate=0.3 the first 30% of the CSV rows become the test split and the
# remaining 70% the training split.
train_dataset = BeeDataset(mode='train', rate=0.3)
test_dataset = BeeDataset(mode='test', rate=0.3)

print('train_data len: {}, test_data len:{}'.format(len(train_dataset), len(test_dataset)))
train_data len: 500, test_data len:214
二、模型配置
复习下卷积嘛,用啥摩托车,手写好了。
#定义卷积网络
import paddle.nn as nn
import paddle.nn.functional as F
from visualdl import LogWriter


class MyCNN(nn.Layer):
    """Small hand-written CNN: two 3x3 convolutions, one max-pool, one linear layer.

    For a 64x64 RGB input the spatial size goes 64 -> 62 -> 60 -> 30, which is
    why the linear layer expects ``128 * 30 * 30`` input features.

    Args:
        log_conv1: When true (the default), every forward pass writes up to
            ten first-layer feature-map images to ``./chk_points/conv1/``
            through VisualDL. Pass ``False`` to skip this logging.
    """

    def __init__(self, log_conv1=True):
        super().__init__()
        self.log_conv1 = log_conv1
        self.hidden1 = nn.Conv2D(in_channels=3,    # input channels
                                 out_channels=64,  # number of kernels
                                 kernel_size=3,    # kernel size
                                 stride=1)         # stride
        self.hidden2 = nn.Conv2D(in_channels=64,
                                 out_channels=128,
                                 kernel_size=3,
                                 stride=1)
        self.hidden3 = nn.MaxPool2D(kernel_size=2,  # pooling window size
                                    stride=2)       # pooling stride
        self.hidden4 = nn.Linear(in_features=128 * 30 * 30, out_features=2)

    def _log_conv1_features(self, features, max_images=10):
        """Write the first three conv1 channels of each sample as an image."""
        # Never index past the batch: the batch may hold fewer than
        # ``max_images`` samples (e.g. a single-sample summary pass).
        count = min(max_images, features.shape[0])
        with LogWriter(logdir="./chk_points/conv1/") as writer:
            for i in range(count):
                myimg = features[i][0:3].numpy()
                # Debug output of the raw (C, H, W) feature maps.
                print(myimg.shape)
                print(myimg)
                # add_image treats the last axis as channels by default, so
                # the (C, H, W) array must be reordered to (H, W, C).
                writer.add_image(tag='conv1',
                                 img=myimg.transpose((1, 2, 0)),
                                 step=i)

    def forward(self, input):
        """Run the forward pass and return softmax class probabilities.

        Args:
            input: Image batch in ``[N, 3, 64, 64]`` layout.

        Returns:
            Tensor of shape ``[N, 2]`` produced by ``F.softmax``.
        """
        x = self.hidden1(input)
        if self.log_conv1:
            self._log_conv1_features(x)
        x = F.relu(x)
        x = self.hidden2(x)
        x = F.relu(x)
        x = self.hidden3(x)
        # The convolution output is [N, C, H, W]; flatten each sample to a
        # vector of length K = C * H * W so the batch becomes an N x K matrix
        # for the fully connected layer.
        x = paddle.reshape(x, shape=[-1, 128 * 30 * 30])
        x = self.hidden4(x)
        out = F.softmax(x)
        return out
三、模型训练 && 四、模型评估
import paddle
from paddle import Model
# Build the network, wrap it in Paddle's high-level Model API and print a
# per-layer summary for a single 3x64x64 input.
myCNN = MyCNN()
model = Model(myCNN)
model.summary(input_size=(1, 3, 64, 64))
(3, 62, 62)
[[[ 0.7852405 1.1204937 1.0803782 ... 0.79842985 0.68843880.8554362 ][ 0.67254984 1.4634426 1.4180253 ... 1.1190094 0.73308060.36009166][ 0.77952844 1.25177 0.62106866 ... 0.781536 0.54997471.185901 ]...[ 1.1787175 1.165362 0.4543565 ... 0.45226488 1.70400011.3579762 ][ 1.3305937 0.9702562 0.6092412 ... 0.56712335 1.63293650.7801384 ][ 0.5835558 0.7375888 1.1013889 ... 1.4835536 0.84016542.0126145 ]][[-1.8209429 -1.010055 -2.013126 ... -0.6853134 -1.2335591-1.0887115 ][-1.1185786 -1.1789582 -1.3160777 ... -1.0062238 -1.5497901-0.7702041 ][-1.0398924 -1.2710589 -0.8835699 ... -1.412867 -0.75782824-0.66315925]...[-0.45574644 -0.63937426 -0.54133946 ... -0.7855961 -0.99309015-0.50211173][-1.2958944 -1.4259287 -1.220675 ... -0.6283496 -0.8967111-0.879598 ][-0.4558576 -0.9036231 -0.61986226 ... -1.2891432 -0.8856837-1.3334211 ]][[-0.16028228 -1.4461167 -0.8729783 ... 0.03628984 0.48390952-0.23797147][-0.1558084 -0.35822743 -0.15980428 ... -0.86909014 0.12811811-0.19449231][-1.0637001 -0.90246856 -0.6251344 ... 0.20142564 -0.9001509-0.6823548 ]...[-0.49504218 0.37776464 -0.37616605 ... -0.1840626 -0.74749-0.11741641][-0.7924826 -0.7401607 -0.12474468 ... -0.9140016 -0.7107928-0.23936886][-1.1824067 -0.1575087 -0.43380502 ... -0.540684 -0.65324664-0.5257259 ]]]---------------------------------------------------------------------------error Traceback (most recent call last)<ipython-input-40-198a0f8bb235> in <module>3 myCNN=MyCNN()4 model= Model(myCNN)
----> 5 model.summary((1,3, 64, 64))/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in summary(self, input_size, dtype)1879 else:1880 _input_size = self._inputs
-> 1881 return summary(self.network, _input_size, dtype)1882 1883 def _verify_spec(self, specs, shapes=None, dtypes=None, is_input=False):/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary(net, input_size, dtypes)147 148 _input_size = _check_input(_input_size)
--> 149 result, params_info = summary_string(net, _input_size, dtypes)150 print(result)151 </opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/decorator.py:decorator-gen-342> in summary_string(model, input_size, dtypes)/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py in _decorate_function(func, *args, **kwargs)313 def _decorate_function(func, *args, **kwargs):314 with self:
--> 315 return func(*args, **kwargs)316 317 @decorator.decorator/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary_string(model, input_size, dtypes)274 275 # make a forward pass
--> 276 model(*x)277 278 # remove these hooks/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)889 self._built = True890
--> 891 outputs = self.forward(*inputs, **kwargs)892 893 for forward_post_hook in self._forward_post_hooks.values():<ipython-input-39-e30af9085378> in forward(self, input)26 print(myimg.shape)27 print(myimg)
---> 28 writer.add_image(tag='conv1',img=myimg, step=i)29 x=F.relu(x)30 # print(x.shape)/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/visualdl/writer/writer.py in add_image(self, tag, img, step, walltime, dataformats)191 self._get_file_writer().add_record(192 image(tag=tag, image_array=img, step=step, walltime=walltime,
--> 193 dataformats=dataformats))194 195 def add_text(self, tag, text_string, step=None, walltime=None):/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/visualdl/component/base_component.py in image(tag, image_array, step, walltime, dataformats)169 image_array = denormalization(image_array)170 image_array = convert_to_HWC(image_array, dataformats)
--> 171 image_bytes = imgarray2bytes(image_array)172 image = Record.Image(encoded_image_string=image_bytes)173 return Record(values=[/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/visualdl/component/base_component.py in imgarray2bytes(np_array)70 import cv271
---> 72 np_array = cv2.cvtColor(np_array, cv2.COLOR_BGR2RGB)73 ret, buf = cv2.imencode(".png", np_array)74 img_bin = Image.fromarray(np.uint8(buf)).tobytes("raw")error: OpenCV(4.1.1) /io/opencv/modules/imgproc/src/color.simd_helpers.hpp:92: error: (-2:Unspecified error) in function 'cv::impl::{anonymous}::CvtHelper<VScn, VDcn, VDepth, sizePolicy>::CvtHelper(cv::InputArray, cv::OutputArray, int) [with VScn = cv::impl::{anonymous}::Set<3, 4>; VDcn = cv::impl::{anonymous}::Set<3, 4>; VDepth = cv::impl::{anonymous}::Set<0, 2, 5>; cv::impl::{anonymous}::SizePolicy sizePolicy = (cv::impl::<unnamed>::SizePolicy)2u; cv::InputArray = const cv::_InputArray&; cv::OutputArray = const cv::_OutputArray&]'
> Invalid number of channels in input image:
> 'VScn::contains(scn)'
> where
> 'scn' is 62
# Training configuration
model.prepare(
    optimizer=paddle.optimizer.Adam(learning_rate=0.00001,
                                    parameters=model.parameters()),  # optimizer
    # The network's forward pass already ends with softmax, so the loss must
    # not apply softmax a second time.
    loss=paddle.nn.CrossEntropyLoss(use_softmax=False),  # loss function
    metrics=paddle.metric.Accuracy(),  # evaluation metric
)
# Callback that records training curves for the VisualDL tool
visualdl = paddle.callbacks.VisualDL(log_dir='visualdl_log')
# Launch the full training loop
model.fit(
    train_dataset,            # training dataset
    # test_dataset,           # evaluation dataset (disabled)
    epochs=100,               # total number of training epochs
    batch_size=256,           # samples per batch
    shuffle=True,             # shuffle the samples
    verbose=1,                # log display format
    save_dir='./chk_points/',  # directory for periodic checkpoints
    callbacks=[visualdl],     # callbacks to run
)
###模型存储
将我们训练得到的模型进行保存,以便后续评估和测试使用。
# Evaluate on the held-out split, then persist the trained model.
model.evaluate(test_dataset, verbose=2)
model.save(path='model_save_dir')
五、模型预测
print('测试数据集样本量:{}'.format(len(test_dataset)))
# Run prediction
result = model.predict(test_dataset)
# Class index -> display name
LABEL_MAP = ['偷懒的小蜜蜂','勤劳的小蜜蜂']
# Sample indices (currently unused; the loop below reports every sample)
indexs = [2, 38, 56, 92, 100, 101]
# Read the true labels from the raw [file_name, label] records rather than
# through __getitem__, which would decode and transform every image again
# just to obtain its label.
for idx, (_, label_text) in enumerate(test_dataset.get_origin_data()):
    predict_label = np.argmax(result[0][idx])
    real_label = int(label_text)
    print('样本ID:{}, 真实标签:{}, 预测值:{}'.format(idx, LABEL_MAP[real_label], LABEL_MAP[predict_label]))
# Raw [file_name, label] records backing the test split
origin_data = test_dataset.get_origin_data()
print(len(origin_data))
# 定义画图方法
from PIL import Image
import matplotlib.font_manager as font_manager

# Font file used for the figure titles (relative path); presumably SimHei so
# that the Chinese title text renders - confirm the file is available.
fontpath = 'SIMHEI.TTF'
font = font_manager.FontProperties(fname=fontpath, size=10)


def show_img(img, predict):
    """Display ``img`` in a new figure with ``predict`` as its title.

    Args:
        img: Image accepted by ``plt.imshow``.
        predict: Title text drawn with the module-level ``font``.
    """
    plt.figure()
    # Matplotlib property names are lower-case; the capitalised
    # ``FontProperties`` keyword is rejected by current releases.
    plt.title(predict, fontproperties=font)
    plt.imshow(img, cmap=plt.cm.binary)
    plt.show()


# Show a sample of the predictions
# Display at most the first ten test samples; slicing keeps the loop within
# bounds when fewer than ten records exist.
for i, (file_name, label_text) in enumerate(origin_data[:10]):
    img_path = 'dataset/PollenDataset/images/' + file_name
    real_label = int(label_text)
    predict_label = int(np.argmax(result[0][i]))
    img = Image.open(img_path)
    # Use this loop's own index in the title, not a variable left over from
    # an earlier loop.
    title = '样本ID:{}, 真实标签:{}, 预测值:{}'.format(i, LABEL_MAP[real_label], LABEL_MAP[predict_label])
    show_img(img, title)
奥奥,看起来小蜜蜂们都很勤劳啊,哈哈哈。
https://aistudio.baidu.com/aistudio/projectdetail/1549057
这篇关于飞桨告诉我谁是最勤劳的小蜜蜂的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!