[TensorFlow fully connected neural network] MNIST handwritten digit recognition

2024-09-07 06:38


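Main content:
Use TensorFlow to build a traditional three-layer fully connected neural network that acts as a multi-class classifier for character recognition: given an image of a handwritten character, it predicts the corresponding digit. The model is trained and evaluated on the MNIST dataset.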
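To make the "three-layer" layout concrete: the script uses `layer_dims = [784, 64, 128, 10]`, meaning 784 input pixels, hidden layers of 64 and 128 units, and 10 output classes. Layer l gets a weight matrix of shape `(layer_dims[l], layer_dims[l-1])` and a bias column of shape `(layer_dims[l], 1)`. The sketch below only counts those shapes; the helper name `parameter_shapes` is hypothetical and not part of the original script.

```python
def parameter_shapes(layer_dims):
    """Return {name: shape} for W1..WL and b1..bL (hypothetical helper)."""
    shapes = {}
    for l in range(1, len(layer_dims)):
        # Weights map layer l-1 activations (columns) to layer l units (rows)
        shapes["W" + str(l)] = (layer_dims[l], layer_dims[l - 1])
        # One bias per unit, stored as a column so it broadcasts over examples
        shapes["b" + str(l)] = (layer_dims[l], 1)
    return shapes


layer_dims = [784, 64, 128, 10]
shapes = parameter_shapes(layer_dims)
total = sum(rows * cols for rows, cols in shapes.values())
for name, shape in sorted(shapes.items()):
    print(name, shape)
print("trainable parameters:", total)  # 59850
```

Almost all of the roughly 60k parameters sit in `W1` (64 x 784 = 50176), which is typical for an MLP on raw pixels.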

```python
# coding: utf-8
# Written for TensorFlow 1.x: tensorflow.examples.tutorials (and tf.placeholder /
# tf.Session used later) are not available in TensorFlow 2.x.
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math

# Download (if needed) and load MNIST with one-hot encoded labels
mnist = input_data.read_data_sets("./mnist/", one_hot=True)

print("Training set:", mnist.train.images.shape)
print("Training set labels:", mnist.train.labels.shape)
print("Dev Set(Cross Validation set):", mnist.validation.images.shape)
print("Dev Set labels:", mnist.validation.labels.shape)
print("Test Set:", mnist.test.images.shape)
print("Test Set labels:", mnist.test.labels.shape)

x_train = mnist.train.images
y_train = mnist.train.labels
x_dev = mnist.validation.images
y_dev = mnist.validation.labels
x_test = mnist.test.images
y_test = mnist.test.labels


def display_digit(index):
    """Show training image `index` together with its label."""
    print(y_train[index])                     # one-hot label vector
    label = y_train[index].argmax(axis=0)     # position of the 1 = the digit
    image = x_train[index].reshape([28, 28])  # 784 pixels back to a 28x28 image
    plt.title("Example: %d  Label: %d" % (index, label))
    plt.imshow(image, cmap=plt.get_cmap("gray_r"))
    plt.show()


display_digit(5)
```
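`display_digit` relies on two storage conventions: with `one_hot=True` each label is a length-10 vector holding a single 1, so `argmax` recovers the digit; and each image is a flat vector of 784 pixels in row-major order, so `reshape([28, 28])` restores the picture. The NumPy sketch below shows both on synthetic data (not real MNIST values).

```python
import numpy as np

# A one-hot label for the digit 3 (10 classes, as produced with one_hot=True)
one_hot = np.zeros(10, dtype=np.float32)
one_hot[3] = 1.0
label = one_hot.argmax(axis=0)  # index of the 1 -> 3

# A synthetic flattened 784-pixel "image": pixel k simply holds the value k
flat = np.arange(784, dtype=np.float32)
image = flat.reshape([28, 28])  # row-major: row r holds pixels 28*r .. 28*r+27
print(label, image.shape, image[1, 0])  # 3 (28, 28) 28.0
```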
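```python
print(y_train[5].shape)

# Following Andrew Ng's convention, arrange the examples as columns:
# every matrix now has shape (number of features, number of examples)
x_train = x_train.T
y_train = y_train.T
x_dev = x_dev.T
y_dev = y_dev.T
x_test = x_test.T
y_test = y_test.T
print("x_train shape", x_train.shape)
```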
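After the transpose each column is one example, which is why the weights have shape `(n_l, n_{l-1})`, the biases are column vectors, and the forward pass computes `tf.matmul(W1, X) + b1`. The NumPy sketch below mirrors the first layer on a tiny random batch (synthetic values, not MNIST) and checks that each output column depends only on its own input column.

```python
import numpy as np

m = 5  # tiny synthetic batch instead of the full training set
x_rows = np.random.rand(m, 784).astype(np.float32)  # input_data layout: (examples, features)
X = x_rows.T                                        # column layout: (features, examples)

W1 = np.random.randn(64, 784).astype(np.float32)  # shape (layer_dims[1], layer_dims[0])
b1 = np.random.randn(64, 1).astype(np.float32)    # one bias per hidden unit
Z1 = W1 @ X + b1  # b1 (64, 1) broadcasts across the m columns
print(X.shape, Z1.shape)  # (784, 5) (64, 5)
```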
```python
print("y_train shape", y_train.shape)


def random_mini_batches(X, Y, mini_batch_size=64):
    """
    Creates a list of random minibatches from (X, Y)

    Arguments:
    X -- input data, of shape (input size, number of examples)
    Y -- one-hot label matrix, of shape (number of classes, number of examples)
    mini_batch_size -- size of the mini-batches, integer

    Returns:
    mini_batches -- list of synchronous (mini_batch_X, mini_batch_Y)
    """
    m = X.shape[1]  # number of training examples
    mini_batches = []

    # Step 1: Shuffle (X, Y) with the same column permutation
    permutation = list(np.random.permutation(m))
    shuffled_X = X[:, permutation]
    shuffled_Y = Y[:, permutation].reshape((-1, m))

    # Step 2: Partition (shuffled_X, shuffled_Y). Minus the end case.
    num_complete_minibatches = math.floor(m / mini_batch_size)  # number of full-size mini-batches
    for k in range(0, num_complete_minibatches):
        mini_batch_X = shuffled_X[:, k * mini_batch_size:(k + 1) * mini_batch_size]
        mini_batch_Y = shuffled_Y[:, k * mini_batch_size:(k + 1) * mini_batch_size]
        mini_batches.append((mini_batch_X, mini_batch_Y))

    # Handling the end case (last mini-batch < mini_batch_size)
    if m % mini_batch_size != 0:
        mini_batch_X = shuffled_X[:, mini_batch_size * num_complete_minibatches:]
        mini_batch_Y = shuffled_Y[:, mini_batch_size * num_complete_minibatches:]
        mini_batches.append((mini_batch_X, mini_batch_Y))

    return mini_batches
```
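A quick check of the partition arithmetic, assuming the default `read_data_sets` split of 55000 training images: 55000 / 64 gives 859 full mini-batches plus one end-case batch of 55000 - 859 * 64 = 24 examples, so each epoch runs 860 batches. Consequently, averaging the batch costs with `int(m / minibatch_size)` = 859 as the divisor slightly inflates the per-epoch cost. The helper `minibatch_sizes` below is hypothetical; it reproduces only the batch sizes, not the shuffling.

```python
import math


def minibatch_sizes(m, mini_batch_size=64):
    """Sizes of the batches random_mini_batches would produce for m examples (hypothetical helper)."""
    num_complete = math.floor(m / mini_batch_size)
    sizes = [mini_batch_size] * num_complete
    if m % mini_batch_size != 0:
        sizes.append(m % mini_batch_size)  # the smaller end-case batch
    return sizes


sizes = minibatch_sizes(55000, 64)
print(len(sizes), sizes[-1], sum(sizes))  # 860 24 55000
```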
```python
# Parameter initialization
# Layer sizes: 784 input pixels, hidden layers of 64 and 128 units, 10 output classes
layer_dims = [784, 64, 128, 10]


def init_parameters(layer_dims):
    parameters = {}
    L = len(layer_dims) - 1  # number of layers in the network (3)
    for l in range(1, L + 1):
        # Standard-normal initialization (mean 0, stddev 1), as in the original
        parameters["W" + str(l)] = tf.Variable(tf.random_normal([layer_dims[l], layer_dims[l - 1]]))
        parameters["b" + str(l)] = tf.Variable(tf.random_normal([layer_dims[l], 1]))
    return parameters


def forward_propagation(X, parameters):
    W1 = parameters['W1']
    b1 = parameters['b1']
    W2 = parameters['W2']
    b2 = parameters['b2']
    W3 = parameters['W3']
    b3 = parameters['b3']

    Z1 = tf.add(tf.matmul(W1, X), b1)   # Z1 = np.dot(W1, X) + b1
    A1 = tf.nn.relu(Z1)                 # A1 = relu(Z1)
    Z2 = tf.add(tf.matmul(W2, A1), b2)  # Z2 = np.dot(W2, A1) + b2
    A2 = tf.nn.relu(Z2)                 # A2 = relu(Z2)
    Z3 = tf.add(tf.matmul(W3, A2), b3)  # Z3 = np.dot(W3, A2) + b3 (logits; softmax is applied in the cost)
    return Z3


def compute_cost(Z3, Y):
    """
    Computes the cost

    Arguments:
    Z3 -- output of forward propagation (output of the last LINEAR unit), of shape (10, number of examples)
    Y -- "true" labels vector placeholder, same shape as Z3

    Returns:
    cost -- Tensor of the cost function
    """
    # tf.nn.softmax_cross_entropy_with_logits expects (number of examples, number of classes)
    logits = tf.transpose(Z3)
    labels = tf.transpose(Y)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    return cost


def tf_nn_model(X_train, Y_train, X_test, Y_test, layer_dims, learning_rate=0.001,
                num_epochs=100, minibatch_size=64, print_cost=True):
    (n_x, m) = X_train.shape  # (n_x: input size, m: number of examples in the train set)
    n_y = Y_train.shape[0]    # n_y: output size
    costs = []                # to keep track of the cost

    X = tf.placeholder(tf.float32, [n_x, None], name="X")
    Y = tf.placeholder(tf.float32, [n_y, None], name="Y")

    parameters = init_parameters(layer_dims)
    Z3 = forward_propagation(X, parameters)
    cost = compute_cost(Z3, Y)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(num_epochs):
            epoch_cost = 0.  # Defines a cost related to an epoch
            minibatches = random_mini_batches(X_train, Y_train, minibatch_size)
            # Count the final partial batch too, so epoch_cost is a true average
            num_minibatches = len(minibatches)
            for minibatch in minibatches:
                (minibatch_X, minibatch_Y) = minibatch
                _, minibatch_cost = sess.run([optimizer, cost],
                                             feed_dict={X: minibatch_X, Y: minibatch_Y})
                epoch_cost += minibatch_cost / num_minibatches
            if print_cost and epoch % 10 == 0:
                print("Cost after epoch %i: %f" % (epoch, epoch_cost))
            if print_cost and epoch % 5 == 0:
                costs.append(epoch_cost)

        # plot the cost (one point every 5 epochs)
        plt.plot(np.squeeze(costs))
        plt.ylabel('cost')
        plt.xlabel('epochs (per fives)')
        plt.title("Learning rate =" + str(learning_rate))
        plt.show()

        parameters = sess.run(parameters)
        print("Parameters have been trained!")

        # Column i is example i, so compare the argmax along axis 0
        correct_prediction = tf.equal(tf.argmax(Z3, axis=0), tf.argmax(Y, axis=0))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("Train Accuracy:", accuracy.eval({X: X_train, Y: Y_train}))
        print("Test Accuracy:", accuracy.eval({X: X_test, Y: Y_test}))

        return parameters


tf_nn_model(x_train, y_train, x_test, y_test, layer_dims, learning_rate=0.001,
            num_epochs=100, minibatch_size=64, print_cost=True)
```
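`compute_cost` transposes the logits so that each row is one example, lets `tf.nn.softmax_cross_entropy_with_logits` apply a softmax to each row and take the cross-entropy against the one-hot labels, and then averages with `tf.reduce_mean`. The pure-Python sketch below computes the same quantity for illustration; it is not TensorFlow's implementation, and the function names are hypothetical.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Cross-entropy of softmax(logits) against a one-hot (or probability) label vector, one example."""
    # Subtract the max logit before exponentiating for numerical stability (log-sum-exp trick)
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    # -sum_k y_k * log softmax(z)_k, where log softmax(z)_k = z_k - logsumexp(z)
    return -sum(y * (z - log_sum_exp) for z, y in zip(logits, labels))


def mean_cost(batch_logits, batch_labels):
    """Average over examples, mirroring tf.reduce_mean in compute_cost."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


one_hot_7 = [0.0] * 10
one_hot_7[7] = 1.0
uniform = [0.0] * 10  # equal logits -> probability 0.1 for every class
print(round(softmax_cross_entropy(uniform, one_hot_7), 6))  # 2.302585 (= ln 10)
```

Equal logits over 10 classes give ln 10, about 2.30, which is a handy reference point when reading the cost log that follows.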

Output:

Cost after epoch 0: 75.913229
Cost after epoch 10: 1.541095
Cost after epoch 20: 0.436585
Cost after epoch 30: 0.174160
Cost after epoch 40: 0.090298
Cost after epoch 50: 0.064457
Cost after epoch 60: 0.044082
Cost after epoch 70: 0.035504
Cost after epoch 80: 0.022698
Cost after epoch 90: 0.023649
Parameters have been trained!
Train Accuracy: 0.994545
Test Accuracy: 0.9427
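The accuracy lines come from `tf.equal(tf.argmax(Z3), tf.argmax(Y))`: taking the argmax along axis 0 picks, for every column (example), the row with the largest value, and the mean of the matches is the accuracy. On the 10000-image MNIST test set, 0.9427 corresponds to 9427 correctly classified images. The pure-Python sketch below applies the same rule to a tiny synthetic 3-class, 4-example matrix; the helper names are hypothetical.

```python
def column_argmax(matrix):
    """argmax over axis 0 of a list-of-rows matrix: one row index per column (example)."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    return [max(range(n_rows), key=lambda r: matrix[r][c]) for c in range(n_cols)]


def accuracy(Z3, Y):
    """Fraction of columns where the largest logit sits on the same row as the 1 in Y."""
    predicted = column_argmax(Z3)
    actual = column_argmax(Y)
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# 3 classes x 4 examples (one example per column), synthetic values
Z3 = [[2.0, 0.1, 0.3, 1.0],
      [0.5, 3.0, 0.2, 0.0],
      [0.1, 0.2, 4.0, 2.0]]
Y = [[1, 0, 0, 1],
     [0, 1, 0, 0],
     [0, 0, 1, 0]]
print(accuracy(Z3, Y))  # 0.75
```

The last column is the miss: its largest logit is in row 2 while its label is class 0.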
Return value: a dict of trained float32 NumPy arrays, W1 (64 x 784), b1 (64 x 1), W2 (128 x 64), b2 (128 x 1), W3 (10 x 128) and b3 (10 x 1).




http://www.chinasem.cn/article/1144337
