将分类网络应用在android中 part2,用自己的训练结果应用android

2024-02-04 03:48

本文主要是介绍将分类网络应用在android中 part2,用自己的训练结果应用android,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

准备工作

编译应用

代码解读

 

准备工作

1.保存checkpoint

可以参考之前的一篇利用tf slim进行分类网络训练的博客,博客地址,如果按照里面的操作步骤进行训练网络,我们会得到保存下来的checkpoint文件。

model.ckpt-5000.data-00000-of-00001 --> 保存了当前参数值
model.ckpt-5000.index --> 保存了当前参数名
model.ckpt-5000.meta --> 保存了当前graph结构图

这样训练的脚本就直接帮我们完成了checkpoint的保存。

但是如果是自己实现的网络结构和网络训练,那我们需要使用下面的代码来保存checkpoint,然后我们同样也会得到这三类文件。

saver = tf.train.Saver()
saver.save(sess, './data/train_logs_1/model.chkp')

2.根据meta生成freezed pb

下面就是要根据生成的checkpoint文件来生成freezed protobuf文件。什么是freezed呢?其实就是将参数的值和graph结合起来保存成pb文件,这样后续使用的时候就只需要直接输入input进行计算就好了,也不用还原网络结构。当然pb文件里面其实都是二进制的信息,也无法还原网络结构的,我们在运算的时候只能按照里面记录的运算方式进行计算。

import tensorflow as tfmeta_path = './data/train_logs_1/model.ckpt-5000.meta' # Your .meta file
output_node_names = ['MobilenetV2/Predictions/Reshape_1']    # Output nodeswith tf.Session() as sess:# Restore the graphsaver = tf.train.import_meta_graph(meta_path)# Load weightssaver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))# Freeze the graphfrozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,output_node_names)# Save the frozen graphwith open('./data/train_logs_1/freeze_graph_5000.pb', 'wb') as f:f.write(frozen_graph_def.SerializeToString())

上面的代码可以保存一个freezed_graph.pb文件。在使用convert_variables_to_constants的时候需要一个output_node_names,如果不知道output的具体名字可以先用如下的代码查看

import tensorflow as tfmeta_path = './data/train_logs_1/model.ckpt-5000.meta' # Your .meta filewith tf.Session() as sess:# Restore the graphsaver = tf.train.import_meta_graph(meta_path)# Load weightssaver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))graph = tf.get_default_graph()with open('./data/train_logs_1/operations_5000.txt', 'wb') as f:for op in graph.get_operations():f.writelines(str(op.name) + ',' + str(op.values()) + '\n')

这里会将graph里面包含的所有操作打印出来,因为内容比较多,所以存入文件方便查看,可以从operations_5000.txt中看到在loss之前最后一个输出就是MobilenetV2/Predictions/Reshape_1,所以我们在convert_variables_to_constants中填入的output_node_names为['MobilenetV2/Predictions/Reshape_1'],其实可以填入好多个output,但是我们的分类网络只需要一个。

MobilenetV2/Predictions/Reshape_1,(<tf.Tensor 'MobilenetV2/Predictions/Reshape_1:0' shape=(32, 243) dtype=float32>,)
softmax_cross_entropy_loss/Rank,(<tf.Tensor 'softmax_cross_entropy_loss/Rank:0' shape=() dtype=int32>,)

3.测试freezed pb

前面两个步骤后其实freezed pb文件就已经保存成功了,但是还需要测试一下我们保存的pb文件是否可靠,是否可以通过load这个pb文件就进行预测。

首先我们需要load pb文件

def load_graph(model_file):graph = tf.Graph()graph_def = tf.GraphDef()with open(model_file, "rb") as f:graph_def.ParseFromString(f.read())with graph.as_default():tf.import_graph_def(graph_def)return graph

接着取出input和output的tensor

  input_operation = graph.get_operation_by_name(input_name)output_operation = graph.get_operation_by_name(output_name)

最后sess run就可以得到结果了

  with tf.Session(graph=graph) as sess:results = sess.run(output_operation.outputs[0], {input_operation.outputs[0]: t})results = np.squeeze(results)print(results.shape)top_k = results.argsort()[-5:][::-1]labels = load_labels(label_file)

虽然测试pb文件代码很简单,但是我们可能会遇到两个坑。

第一个可能运行后会报错。因为我们打印出来的results的shape是[32, 243],243是我们分类的类别数,但是为何有32个243的数组呢?

(1, 224, 224, 3)
2018-09-04 18:19:58.424800: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:03:00.0, compute capability: 6.1)
(32, 243)
Traceback (most recent call last):File "test_freeze_meta.py", line 80, in <module>print(labels[i], results[i])
TypeError: only integer scalar arrays can be converted to a scalar index

从我们前面打印出来的output的size可以看出MobilenetV2/Predictions/Reshape_1的shape是(32, 243),是因为32是我们之前training的batch size,而我们保存meta的时候将这个size保存下来了。所以导致我们predict的时候出来的结果也是(32, 243),但是我们其实只预测了一张图片。

如果遇到了这个问题,可以将training时的batch size改成1,基于前面的checkpoint再进行一次训练生成新的meta文件,然后重复上面的步骤进行操作即可。比如我这边会用下面的命令接着进行training

python train_image_classifier.py \--train_dir=./data/train_logs_1 \--dataset_dir=./data/mydata \--dataset_name=mydata \--dataset_split_name=train \--model_name=mobilenet_v2 \--train_image_size=224 \--batch_size=1

会生成model.ckpt-5001.data-00000-of-00001,model.ckpt-5001.index和model.ckpt-5001.meta文件,然后重新生成operations_5001.txt,可以发现里面shape已经变过来了

MobilenetV2/Predictions/Reshape_1,(<tf.Tensor 'MobilenetV2/Predictions/Reshape_1:0' shape=(1, 243) dtype=float32>,)
softmax_cross_entropy_loss/Rank,(<tf.Tensor 'softmax_cross_entropy_loss/Rank:0' shape=() dtype=int32>,)

后面进行预测就不会有报错。但是其实并没有结束,因为我们还可能遇到第二个坑。

预测的结果特别不准确,而且多跑几次会发现每次结果都不一样。这是因为batch normalization和dropout的随机性导致的

    with slim.arg_scope([slim.batch_norm, slim.dropout],is_training=is_training):

代码中在搭建网络的时候很清楚的对这两种操作区分了是否是training状态。所以我们现在的做法是搭建网络的时候传入这个参数为false,然后进行一次training生成checkpoint 5002。

    network_fn = nets_factory.get_network_fn(FLAGS.model_name,num_classes=(dataset.num_classes - FLAGS.labels_offset),weight_decay=FLAGS.weight_decay,is_training=False)

如果对5002 checkpoint进行评估会发现结果是非常不准确的,没关系,我们只需要用到他的meta文件。

然后在用第二步中的操作生成freezed pb文件

meta_path = './data/train_logs_1/model.ckpt-5002.meta' # Your .meta file#这里restore的checkpoint需要是准确率比较高的checkpoint,比如ckpt-5001
#如果让latest_checkpoint取到的是5001呢,很简单,修改train_logs_1目录下的checkpoint文件
#修改model_checkpoint_path: "model.ckpt-5001",这样就会自动取5001为checkpoint来恢复数据了
saver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))

然后再进行预测,一切都正常了。比如预测file_name = "./backup/mydata/km335_back/km335_back.jpg"文件,结果是

(243,)
('150:km335_back', 0.99930513)
('151:km335_front', 0.00019936822)
('191:km711_front', 0.00018467742)
('193:km712_front', 0.00015364563)
('220:kmmerge123_back', 4.9729293e-05)

并不是每个人都会遇到这两个问题,如果是自己搭建网络,自己保存checkpoint我想是可以避免的,但是tf-slim是用slim.learning.train接口进行训练和保存checkpoint,所以保存形式不太可控。

代码实现:freeze_meta.py   test_freeze_meta.py

 

编译应用

如果按照上一篇博文(链接)进行了实操,那这一步就会非常容易了。

首先将上一步编译出来的pb文件,和我们分类的label文件拷贝放入tensorflow/examples/android/assets

然后修改ClassifierActivity.java中的代码如下

  private static final String INPUT_NAME = "MobilenetV2/input";private static final String OUTPUT_NAME = "MobilenetV2/Predictions/Reshape_1";  private static final String MODEL_FILE = "file:///android_asset/freeze_graph_5002.pb";private static final String LABEL_FILE = "file:///android_asset/labels.txt";

接着用bazel进行编译

bazel build //tensorflow/examples/android:tensorflow_coin

生成的apk放在了bazel-bin/tensorflow/examples/android目录下,安装启动即可。

但是我们可能会遇到另外一个坑,安装apk后启动会crash,从adb log看到的错误是

09-05 14:07:58.741 16317 16441 E AndroidRuntime: java.lang.IllegalArgumentException: Cannot assign a device for operation 'MobilenetV2/input': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
09-05 14:07:58.741 16317 16441 E AndroidRuntime: 	 [[Node: MobilenetV2/input = Identity[T=DT_FLOAT, _device="/device:GPU:0"](fifo_queue_Dequeue)]]

这是因为我们训练的时候用的是GPU,保存的pb中指定了用GPU进行load,而手机中只有CPU,所以所发生错误。

改动方法是训练的时候添加一个flag --clone_on_cpu=True,就可以将我们的meta保存device指定CPU。

 

代码解读

android代码中关于分类网络的主要是两个文件,一个是ClassifierActivity.java,另一个是TensorFlowImageClassifier.java

1.ClassifierActivity.java

这个文件主要负责camera的preview,将preview中的图片传递给TensorFlowImageClassifier进行分类网络的预测,最后显示预测结果。

# 创建classifier实例
classifier =TensorFlowImageClassifier.create(getAssets(),MODEL_FILE,LABEL_FILE,INPUT_SIZE,IMAGE_MEAN,IMAGE_STD,INPUT_NAME,OUTPUT_NAME);# 调用recognizeImage进行图像识别
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);# 显示预测结果
resultsView.setResults(results);

2.TensorFlowImageClassifier.java

主要负责调用TensorFlowInferenceInterface类的接口进行预测。

# 实例化TensorFlowInferenceInterface,同时会将model载入
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);# 传入input的image数据
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);# 进行计算
inferenceInterface.run(outputNames, logStats);#取出计算结果
inferenceInterface.fetch(outputName, outputs);

 

以上就是全部内容,如果在操作过程中遇到了任何问题可以给我留言,谢谢阅读。

这篇关于将分类网络应用在android中 part2,用自己的训练结果应用android的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/676276

相关文章

Java逻辑运算符之&&、|| 与&、 |的区别及应用

《Java逻辑运算符之&&、||与&、|的区别及应用》:本文主要介绍Java逻辑运算符之&&、||与&、|的区别及应用的相关资料,分别是&&、||与&、|,并探讨了它们在不同应用场景中... 目录前言一、基本概念与运算符介绍二、短路与与非短路与:&& 与 & 的区别1. &&:短路与(AND)2. &:非短

Android WebView无法加载H5页面的常见问题和解决方法

《AndroidWebView无法加载H5页面的常见问题和解决方法》AndroidWebView是一种视图组件,使得Android应用能够显示网页内容,它基于Chromium,具备现代浏览器的许多功... 目录1. WebView 简介2. 常见问题3. 网络权限设置4. 启用 JavaScript5. D

Android如何获取当前CPU频率和占用率

《Android如何获取当前CPU频率和占用率》最近在优化App的性能,需要获取当前CPU视频频率和占用率,所以本文小编就来和大家总结一下如何在Android中获取当前CPU频率和占用率吧... 最近在优化 App 的性能,需要获取当前 CPU视频频率和占用率,通过查询资料,大致思路如下:目前没有标准的

Spring AI集成DeepSeek三步搞定Java智能应用的详细过程

《SpringAI集成DeepSeek三步搞定Java智能应用的详细过程》本文介绍了如何使用SpringAI集成DeepSeek,一个国内顶尖的多模态大模型,SpringAI提供了一套统一的接口,简... 目录DeepSeek 介绍Spring AI 是什么?Spring AI 的主要功能包括1、环境准备2

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav

Android开发中gradle下载缓慢的问题级解决方法

《Android开发中gradle下载缓慢的问题级解决方法》本文介绍了解决Android开发中Gradle下载缓慢问题的几种方法,本文给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、网络环境优化二、Gradle版本与配置优化三、其他优化措施针对android开发中Gradle下载缓慢的问

MobaXterm远程登录工具功能与应用小结

《MobaXterm远程登录工具功能与应用小结》MobaXterm是一款功能强大的远程终端软件,主要支持SSH登录,拥有多种远程协议,实现跨平台访问,它包括多会话管理、本地命令行执行、图形化界面集成和... 目录1. 远程终端软件概述1.1 远程终端软件的定义与用途1.2 远程终端软件的关键特性2. 支持的

Android 悬浮窗开发示例((动态权限请求 | 前台服务和通知 | 悬浮窗创建 )

《Android悬浮窗开发示例((动态权限请求|前台服务和通知|悬浮窗创建)》本文介绍了Android悬浮窗的实现效果,包括动态权限请求、前台服务和通知的使用,悬浮窗权限需要动态申请并引导... 目录一、悬浮窗 动态权限请求1、动态请求权限2、悬浮窗权限说明3、检查动态权限4、申请动态权限5、权限设置完毕后

Android里面的Service种类以及启动方式

《Android里面的Service种类以及启动方式》Android中的Service分为前台服务和后台服务,前台服务需要亮身份牌并显示通知,后台服务则有启动方式选择,包括startService和b... 目录一句话总结:一、Service 的两种类型:1. 前台服务(必须亮身份牌)2. 后台服务(偷偷干

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep