MNIST3_numpy手写全连接神经网络

2023-11-02 14:32

本文主要是介绍MNIST3_numpy手写全连接神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、链式求导
  • 二、numpy layer和反向传播
    • 反向传播
  • 三、MNIST训练及测试

一、链式求导

在这里插入图片描述

二、numpy layer和反向传播

全部脚本见笔者github: numpynn.py


import numpy as npclass npLayer():def __init__(self, n_input, n_out, activation=None, weights=None,bias=None):self.weights = weights if weights is not None else np.random.randn(n_input, n_out) * np.sqrt(1 / n_out)self.bias = bias if bias is not None else np.random.randn(n_out) * 0.1self.activation = activation self.last_activation = None self.error = None self.delta = None def activate(self, x):# 前向传播r = np.dot(x, self.weights) + self.biasself.last_activation = self.apply_activation(r)return self.last_activation def apply_activation(self, r):# 计算激活函数的输出if self.activation is None:return relif self.activation == 'relu':return np.maximum(r, 0)elif self.activation == 'tanh':return np.tanh(r)elif self.activation == 'sigmoid':return 1/(1 + np.exp(-r))return rdef apply_activation_derivative(self, act_r):# 计算激活函数的导数if self.activation is None:return np.ones_like(act_r)elif self.activation == 'relu':return (act_r > 0) * 1elif self.activation == 'tanh':return 1 - act_r ** 2elif self.activation == 'sigmoid':return act_r * (1 - act_r)return act_rdef __call__(self, x):return self.activate(x)

反向传播

    def backpropagation(self, x, y, learning_rate):# 反向传播算法实现## 从后向前计算梯度 output = self.feed_forward(x) # 最后层输出layer_len = len(self._layers)for i in reversed(range(layer_len)):layer = self._layers[i] # 如果是输出层if layer  == self._layers[-1]:delta_i = layer.apply_activation_derivative(output)layer.error = output - ylayer.delta = layer.error * delta_ielse:next_layer = self._layers[i + 1]delta_i = layer.apply_activation_derivative(layer.last_activation)layer.error = np.dot(next_layer.weights, next_layer.delta)layer.delta = layer.error * delta_i# 梯度下降for i in range(layer_len):layer = self._layers[i]o_i = np.atleast_2d(x if i == 0 else self._layers[i - 1].last_activation)layer.weights -= layer.delta * o_i.T * learning_rate

三、MNIST训练及测试


if __name__ == '__main__':mnistdf = get_ministdata()te_index = mnistdf.sample(frac=0.8).index.tolist()mnist_te = mnistdf.loc[te_index, :]mnist_tr = mnistdf.loc[~mnistdf.index.isin(te_index), :]x_tr, y_tr = mnist_tr.iloc[:, :-1].values, mnist_tr.iloc[:, -1].valuesx_te, y_te = mnist_te.iloc[:, :-1].values, mnist_te.iloc[:, -1].valuesprint(x_te.shape)nn = NeuralNetwork()nn.add_layer(npLayer(784, 128, 'relu')) nn.add_layer(npLayer(128, 10, 'sigmoid'))st = time.perf_counter()mses, accs = nn.train(x_tr, x_te, y_tr, y_te, 0.01, 150)cost_ = time.perf_counter() - stprint(f'cost: {cost_:.2f}s',accs)
 ================================================================================
Epoch: # 85, MSE: 0.00713
Accuracy: 93.93 % ================================================================================
Epoch: # 90, MSE: 0.00654
Accuracy: 94.09 % ================================================================================
Epoch: # 95, MSE: 0.00600
Accuracy: 94.27 % ================================================================================
Epoch: # 100, MSE: 0.00558
Accuracy: 94.41 % ================================================================================
Epoch: # 105, MSE: 0.00514
Accuracy: 94.53 % ================================================================================
Epoch: # 110, MSE: 0.00479
Accuracy: 94.65 % ================================================================================
Epoch: # 115, MSE: 0.00447
Accuracy: 94.75 % ================================================================================
Epoch: # 120, MSE: 0.00417
Accuracy: 94.84 % ================================================================================
Epoch: # 125, MSE: 0.00393
Accuracy: 94.93 % ================================================================================
Epoch: # 130, MSE: 0.00370
Accuracy: 94.98 % ================================================================================
Epoch: # 135, MSE: 0.00350
Accuracy: 95.03 %================================================================================
Epoch: # 140, MSE: 0.00332
Accuracy: 95.08 %================================================================================
Epoch: # 145, MSE: 0.00316
Accuracy: 95.12 %================================================================================
Epoch: # 150, MSE: 0.00303
Accuracy: 95.14 %
cost: 1104.11s [0.2034285714285714, 0.5135714285714286, 0.5907142857142857, 0.6798928571428572, 0.74375, 0.7954285714285715
, 0.8364821428571428, 0.863125, 0.8833571428571428, 0.8975178571428571, 0.9077857142857142, 0.9149285714285714, 0.9213214285714286
, 0.9264821428571427, 0.9302142857142858, 0.9336071428571429, 0.9372678571428571, 0.9392857142857143, 0.9408928571428572, 0.9427321428571429
, 0.9440535714285714, 0.94525, 0.9465178571428572, 0.9475178571428572, 0.9483571428571429, 0.9493035714285715, 0.9498214285714286
, 0.9502857142857143, 0.95075, 0.9511607142857144, 0.9513571428571429]

这篇关于MNIST3_numpy手写全连接神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/331270

相关文章

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

SQL 中多表查询的常见连接方式详解

《SQL中多表查询的常见连接方式详解》本文介绍SQL中多表查询的常见连接方式,包括内连接(INNERJOIN)、左连接(LEFTJOIN)、右连接(RIGHTJOIN)、全外连接(FULLOUTER... 目录一、连接类型图表(ASCII 形式)二、前置代码(创建示例表)三、连接方式代码示例1. 内连接(I

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

java如何通过Kerberos认证方式连接hive

《java如何通过Kerberos认证方式连接hive》该文主要介绍了如何在数据源管理功能中适配不同数据源(如MySQL、PostgreSQL和Hive),特别是如何在SpringBoot3框架下通过... 目录Java实现Kerberos认证主要方法依赖示例续期连接hive遇到的问题分析解决方式扩展思考总

Python中连接不同数据库的方法总结

《Python中连接不同数据库的方法总结》在数据驱动的现代应用开发中,Python凭借其丰富的库和强大的生态系统,成为连接各种数据库的理想编程语言,下面我们就来看看如何使用Python实现连接常用的几... 目录一、连接mysql数据库二、连接PostgreSQL数据库三、连接SQLite数据库四、连接Mo

oracle如何连接登陆SYS账号

《oracle如何连接登陆SYS账号》在Navicat12中连接Oracle11g的SYS用户时,如果设置了新密码但连接失败,可能是因为需要以SYSDBA或SYSOPER角色连接,解决方法是确保在连接... 目录oracle连接登陆NmOtMSYS账号工具问题解决SYS用户总结oracle连接登陆SYS账号

VScode连接远程Linux服务器环境配置图文教程

《VScode连接远程Linux服务器环境配置图文教程》:本文主要介绍如何安装和配置VSCode,包括安装步骤、环境配置(如汉化包、远程SSH连接)、语言包安装(如C/C++插件)等,文中给出了详... 目录一、安装vscode二、环境配置1.中文汉化包2.安装remote-ssh,用于远程连接2.1安装2

关于rpc长连接与短连接的思考记录

《关于rpc长连接与短连接的思考记录》文章总结了RPC项目中长连接和短连接的处理方式,包括RPC和HTTP的长连接与短连接的区别、TCP的保活机制、客户端与服务器的连接模式及其利弊分析,文章强调了在实... 目录rpc项目中的长连接与短连接的思考什么是rpc项目中的长连接和短连接与tcp和http的长连接短

numpy求解线性代数相关问题

《numpy求解线性代数相关问题》本文主要介绍了numpy求解线性代数相关问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 在numpy中有numpy.array类型和numpy.mat类型,前者是数组类型,后者是矩阵类型。数组