MNIST3_tf2.keras训练预测

2023-11-02 14:32
文章标签 训练 预测 keras mnist3 tf2

本文主要是介绍MNIST3_tf2.keras训练预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

针对MNIST数据集进行j基于tf2.keras的nn模型训练和预测
部分脚本如下: 完整脚本见笔者github

import pandas as pd 
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Sequential, regularizers, Model
from utils.utils_tools import clock, get_ministdata
import warnings
warnings.filterwarnings(action='ignore')class MyModel(Model):def __init__(self):super(MyModel, self).__init__()self.fc0 = layers.Dense(256, activation='relu')self.bn1 = layers.BatchNormalization()self.fc1 = layers.Dense(32, activation='relu', kernel_regularizer = regularizers.l2(0.06))self.bn2 = layers.BatchNormalization()self.fc2 = layers.Dense(10, activation='sigmoid')def call(self, inputs_, training=None):x = self.fc0(inputs_)x = self.bn1(x)x = self.fc1(x)x = self.bn2(x)return self.fc2(x)def build_nnmodel(input_shape=(None, 784)):nnmodel = MyModel()nnmodel.build(input_shape=input_shape)nnmodel.compile(loss='categorical_crossentropy',optimizer = 'adam',metrics = ['accuracy'])return nnmodelif __name__ == '__main__':mnistdf = get_ministdata()te_index = mnistdf.sample(frac=0.75).index.tolist()mnist_te = mnistdf.loc[te_index, :]mnist_tr = mnistdf.loc[~mnistdf.index.isin(te_index), :]mnist_tr_x, mnist_tr_y = mnist_tr.iloc[:, :-1].values, tf.keras.utils.to_categorical(mnist_tr.iloc[:, -1].values)mnist_te_x, mnist_te_y = mnist_te.iloc[:, :-1].values, tf.keras.utils.to_categorical(mnist_te.iloc[:, -1].values)nnmodel = build_nnmodel()stop = tf.keras.callbacks.EarlyStopping(monitor = 'accuracy', min_delta=0.0001)history = nnmodel.fit(mnist_tr_x, mnist_tr_y,validation_data = (mnist_te_x, mnist_te_y),epochs=10,callbacks = [stop])acc_final = round(max(history.history['accuracy']), 2)print(f"acc:{acc_final:.3f}")predict_ = np.argmax(nnmodel.predict(mnist_te_x), axis=1)te_y = np.argmax(mnist_te_y, axis=1)print(predict_)print(te_y)print(f'auc: {sum(predict_ == te_y)/te_y.shape[0]:.3f}')
"""
14000/14000 [==============================] - 10s 741us/sample - loss: 0.1472 - accuracy: 0.9730 - val_loss: 0.2162 - val_accuracy: 0.9557
Epoch 9/10
13952/14000 [============================>.] - ETA: 0s - loss: 0.1520 - accuracy: 0.9720
Epoch 00009: accuracy did not improve from 0.97300
14000/14000 [==============================] - 23s 2ms/sample - loss: 0.1522 - accuracy: 0.9719 - val_loss: 0.2390 - val_accuracy: 0.9506
acc:0.970
[2 2 8 ... 7 6 6]
[2 2 8 ... 7 6 6]
auc: 0.951, cost: 170.946"""

这篇关于MNIST3_tf2.keras训练预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/331271

相关文章

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录 在深度学习项目中,目标检测是一项重要的任务。本文将详细介绍如何使用Detectron2进行目标检测模型的复现训练,涵盖训练数据准备、训练命令、训练日志分析、训练指标以及训练输出目录的各个文件及其作用。特别地,我们将演示在训练过程中出现中断后,如何使用 resume 功能继续训练,并将我们复现的模型与Model Zoo中的

多云架构下大模型训练的存储稳定性探索

一、多云架构与大模型训练的融合 (一)多云架构的优势与挑战 多云架构为大模型训练带来了诸多优势。首先,资源灵活性显著提高,不同的云平台可以提供不同类型的计算资源和存储服务,满足大模型训练在不同阶段的需求。例如,某些云平台可能在 GPU 计算资源上具有优势,而另一些则在存储成本或性能上表现出色,企业可以根据实际情况进行选择和组合。其次,扩展性得以增强,当大模型的规模不断扩大时,单一云平

神经网络训练不起来怎么办(零)| General Guidance

摘要:模型性能不理想时,如何判断 Model Bias, Optimization, Overfitting 等问题,并以此着手优化模型。在这个分析过程中,我们可以对Function Set,模型弹性有直观的理解。关键词:模型性能,Model Bias, Optimization, Overfitting。 零,领域背景 如果我们的模型表现较差,那么我们往往需要根据 Training l

Tensorflow lstm实现的小说撰写预测

最近,在研究深度学习方面的知识,结合Tensorflow,完成了基于lstm的小说预测程序demo。 lstm是改进的RNN,具有长期记忆功能,相对于RNN,增加了多个门来控制输入与输出。原理方面的知识网上很多,在此,我只是将我短暂学习的tensorflow写一个预测小说的demo,如果有错误,还望大家指出。 1、将小说进行分词,去除空格,建立词汇表与id的字典,生成初始输入模型的x与y d

临床基础两手抓!这个12+神经网络模型太贪了,免疫治疗预测、通路重要性、基因重要性、通路交互作用性全部拿下!

生信碱移 IRnet介绍 用于预测病人免疫治疗反应类型的生物过程嵌入神经网络,提供通路、通路交互、基因重要性的多重可解释性评估。 临床实践中常常遇到许多复杂的问题,常见的两种是: 二分类或多分类:预测患者对治疗有无耐受(二分类)、判断患者的疾病分级(多分类); 连续数值的预测:预测癌症病人的风险、预测患者的白细胞数值水平; 尽管传统的机器学习提供了高效的建模预测与初步的特征重

如何创建训练数据集

在 HuggingFace 上创建数据集非常方便,创建完成之后,通过 API 可以方便的下载并使用数据集,在 Google Colab 上进行模型调优,下载数据集速度非常快,本文通过 Dataset 库创建一个简单的训练数据集。 首先安装数据集依赖 HuggingFace datasetshuggingface_hub 创建数据集 替换为自己的 HuggingFace API key

【YOLO 系列】基于YOLOV8的智能花卉分类检测系统【python源码+Pyqt5界面+数据集+训练代码】

前言: 花朵作为自然界中的重要组成部分,不仅在生态学上具有重要意义,也在园艺、农业以及艺术领域中占有一席之地。随着图像识别技术的发展,自动化的花朵分类对于植物研究、生物多样性保护以及园艺爱好者来说变得越发重要。为了提高花朵分类的效率和准确性,我们启动了基于YOLO V8的花朵分类智能识别系统项目。该项目利用深度学习技术,通过分析花朵图像,自动识别并分类不同种类的花朵,为用户提供一个高效的花朵识别