SKIL/工作流程/在实验中训练模型

2023-10-21 16:50

本文主要是介绍SKIL/工作流程/在实验中训练模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在实验中训练模型

如果你想跟踪结果并进行可重复的评估,实验对于训练模型很有用。一旦你学习了工作间,笔记本和进行实验的基本知识你就准备好用SKIL练一个模型了。


先决条件
这个文档假设你已经设置了一个工作间并在SKIL中创建了一个新的实验。创建实验后,打开“笔记本”选项卡,该选项卡将显示scala的模板笔记本,其中已设置导入和结构化训练代码。

如果你不打算动态加载任何其他依赖项,可以单击工具栏左上角的“play”按钮(形状像一个侧面三角形),以评估模板笔记本中的所有单元,并将SkilContext和deeplearning4j库放到作用域中。
如果你喜欢使用其他库,SKIL已将TensorFlow和Keras预先打包。更多信息请参见实验中的TensorFlow。

 

典型工作流程

为训练而设置的笔记本通常遵循此工作流程:

  1. 将第一个和顶部单元用于动态依赖项(可选)。
  2. 把所有常见的导入放在最上面。
  3. 实例化SkilContext并引用SkilContext.client。
  4. 添加用于加载、拆分和转换数据集的代码。
  5. 编写深度学习模型配置和超参数。
  6. 把数据传入Model.fit() 或者,如果使用多GPU,传入 ParallelWrapper.fit.
  7. 使用测试/验证/维持数据集评估模型。
  8. 将经过训练的模型和评估结果传递给SkilContext进行存储。

 

样例代码
TensorFlow、多个Keras后端和Deeplarning4J是默认情况下可用的深度学习框架。下面的示例代码使用scala语言和deeplearning4j。如果要完全下载示例笔记本,建议使用uci_quickstart_notebook.json。
如果要使用外部库,请使用笔记本第一个单元格中的%spark.dep解释器预加载要在笔记本中使用的任何依赖项。

 

%spark.dep//清除以前添加的项目和仓库
z.reset() // 添加maven仓库
z.addRepo("RepoName").url("RepoURL")// 添加Maven快照仓库
z.addRepo("RepoName").url("RepoURL").snapshot()// 添加私有Maven仓库的凭据
z.addRepo("RepoName").url("RepoURL").username("username").password("password")// 从文件系统添加项目
z.load("/path/to.jar")

在配置模型或运行代码之前,需要将必要的类导入作用域。通常,这涉及到deeplarning4j及其一些实用程序库(如ND4J和DataVec)的导入。还要记住导入SKIL实用程序,以便将模型和评估保存到SKIL存储。下面的代码拥有训练LSTM序列分类器所需要的一切。

import scala.collection.JavaConversions._import io.skymind.zeppelin.utils._
import io.skymind.modelproviders.history.client.ModelHistoryClient
import io.skymind.modelproviders.history.model._import org.deeplearning4j.datasets.iterator._
import org.deeplearning4j.datasets.iterator.impl._
import org.deeplearning4j.nn.api._
import org.deeplearning4j.nn.multilayer._
import org.deeplearning4j.nn.graph._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs._
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex
import org.deeplearning4j.nn.weights._
import org.deeplearning4j.optimize.listeners._
import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter
import org.deeplearning4j.ui.stats.StatsListener
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator
import org.deeplearning4j.eval.Evaluationimport org.datavec.api.transform._
import org.datavec.api.records.reader.RecordReader
import org.datavec.api.records.reader.SequenceRecordReader
import org.datavec.api.records.reader.impl.csv.CSVRecordReader
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader
import org.datavec.api.split.NumberedFileInputSplitimport org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config._
import org.nd4j.linalg.lossfunctions.LossFunctions._
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize
import org.nd4j.linalg.util.ArrayUtilimport java.io.File
import java.net.URL
import java.util.ArrayList
import java.util.Collections
import java.util.List
import java.util.Random

假设你已将数据集序列保存到单独的特征和标签文件中,则可以定义一个CSVSequenceRecordReader。它使用RecordReader基类从csv文件中提取单个序列。最后,在使用神经网络中的数据之前,必须将RecordReader传递给一个扩展DataSetIterator的类。这允许预取和批处理你的训练。

 

val trainFeatures: SequenceRecordReader = new CSVSequenceRecordReader()
trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath + "/%d.csv",0,449))val trainLabels: RecordReader = new CSVRecordReader()
trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath + "/%d.csv",0,449))val minibatch: Int = 10
val numLabelClasses: Int = 6val trainData: MultiDataSetIterator = new RecordReaderMultiDataSetIterator.Builder(minibatch).addSequenceReader("features", trainFeatures).addReader("labels", trainLabels).addInput("features").addOutputOneHot("labels", 0, numLabelClasses).build()

 

最后,初始化网络配置。Deeplarning4J公开了一个称为MultiLayerNetwork的简单接口,并且一个更复杂的配置ComputationGraph可用于多个输入和输出。它们类似于Keras中的两个API,ComputationGraph的工作原理与TensorFlow自己的配置非常相似。
配置网络时,必须首先使用NeuralNetConfiguration Builder定义层、输入、输出和其他超参数。然后传递到ComputationGraphMultiLayerNetwork类,不要忘记调用init()

val conf: ComputationGraphConfiguration = new NeuralNetConfiguration.Builder().seed(123).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER).updater(new Nesterovs(0.005, 0.9)).gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(0.5).graphBuilder().addInputs("input").setInputTypes(InputType.recurrent(1)).addLayer("lstm", new GravesLSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build(), "input").addVertex("pool", new LastTimeStepVertex("input"), "lstm").addLayer("output", new OutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build(), "pool").setOutputs("output").pretrain(false).backprop(true).build()val network_model: ComputationGraph = new ComputationGraph(conf)
network_model.init()

Training the network is fairly simple. You can either use a MultipleEpochsIteratorincluded with Deeplearning4j or manually iterate through each epoch if you prefer to perform other operations such as evaluation.

训练网络相当简单。如果你愿意执行其他操作(如评估),可以使用MultipleEpochsIterator(包括deeplarming4j),也可以手动迭代每个epoch。

for (i <- 0 until nEpochs) {network_model.fit(trainData)// 在测试集上评估:val evaluation = eval(testData)var accuracy = evaluation.accuracy()var f1 = evaluation.f1()println(s"Test set evaluation at epoch $i: Accuracy = $accuracy, F1 = $f1")testData.reset()trainData.reset()
}

Certain datasets might require more complex evaluation. The code below shows you how to create an evaluation method that returns an Evaluation class which is compatible with SKIL's model storage system.

某些数据集可能需要更复杂的评估。下面的代码向你展示了如何创建一个返回与SKIL's的模型存储系统兼容的evaluation类的评估方法。

def eval(it:MultiDataSetIterator) : Evaluation = {val evaluation = new Evaluation(numLabelClasses)it.reset()while (it.hasNext()) {val ds = it.next()val prediction = network_model.outputSingle(ds.getFeatures(0))evaluation.eval(ds.getLabels(0), prediction)}return evaluation
}

最后,使用SkilContext类将模型上传到SKIL并附加评估结果。

var evaluation = eval(testData)
val modelId = skilContext.addModelToExperiment(z, network_model)
val evalId = skilContext.addEvaluationToModel(z, modelId, evaluation)

 

这篇关于SKIL/工作流程/在实验中训练模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/255671

相关文章

Linux流媒体服务器部署流程

《Linux流媒体服务器部署流程》文章详细介绍了流媒体服务器的部署步骤,包括更新系统、安装依赖组件、编译安装Nginx和RTMP模块、配置Nginx和FFmpeg,以及测试流媒体服务器的搭建... 目录流媒体服务器部署部署安装1.更新系统2.安装依赖组件3.解压4.编译安装(添加RTMP和openssl模块

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

springboot启动流程过程

《springboot启动流程过程》SpringBoot简化了Spring框架的使用,通过创建`SpringApplication`对象,判断应用类型并设置初始化器和监听器,在`run`方法中,读取配... 目录springboot启动流程springboot程序启动入口1.创建SpringApplicat

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

通过prometheus监控Tomcat运行状态的操作流程

《通过prometheus监控Tomcat运行状态的操作流程》文章介绍了如何安装和配置Tomcat,并使用Prometheus和TomcatExporter来监控Tomcat的运行状态,文章详细讲解了... 目录Tomcat安装配置以及prometheus监控Tomcat一. 安装并配置tomcat1、安装

MySQL的cpu使用率100%的问题排查流程

《MySQL的cpu使用率100%的问题排查流程》线上mysql服务器经常性出现cpu使用率100%的告警,因此本文整理一下排查该问题的常规流程,文中通过代码示例讲解的非常详细,对大家的学习或工作有一... 目录1. 确认CPU占用来源2. 实时分析mysql活动3. 分析慢查询与执行计划4. 检查索引与表

Git提交代码详细流程及问题总结

《Git提交代码详细流程及问题总结》:本文主要介绍Git的三大分区,分别是工作区、暂存区和版本库,并详细描述了提交、推送、拉取代码和合并分支的流程,文中通过代码介绍的非常详解,需要的朋友可以参考下... 目录1.git 三大分区2.Git提交、推送、拉取代码、合并分支详细流程3.问题总结4.git push