java(kotlin) ai框架djl

2024-06-13 09:28
文章标签 java ai 框架 kotlin djl

本文主要是介绍java(kotlin) ai框架djl,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl--><dependency><groupId>ai.djl</groupId><artifactId>api</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-model-zoo</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>basicdataset</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>model-zoo</artifactId><version>0.28.0</version></dependency><!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘│┌─────────▼─────────┐│       Engine      │└───────┬─┬─────────┘│ │┌───────▼─▼─────────┐│     NDManager     │└───────┬─┬─────────┘│ │┌─────────▼─▼───────────┐│    Dataset └─────────┬─────────────┘│┌─────────▼─────────────┐│  Trainer / Predictor  │└───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;import java.io.IOException;
      import java.nio.file.Paths;public class DjlExample {public static void main(String[] args) throws IOException, ModelException, TranslateException {// 加载模型Criteria<Image, Classifications> criteria = Criteria.builder().optEngine("TensorFlow") // 选择引擎.setTypes(Image.class, Classifications.class).optModelPath(Paths.get("path/to/model")).build();try (Model model = ModelZoo.loadModel(criteria);Predictor<Image, Classifications> predictor = model.newPredictor()) {// 加载图像Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));// 进行推理Classifications result = predictor.predict(img);System.out.println(result);}}
      }
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;import java.io.IOException;public class DJLDatasetExample {public static void main(String[] args) throws IOException, TranslateException {NDManager manager = NDManager.newBaseManager();FashionMnist fashionMnist = FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true) // 32 is the batch size.optLimit(Long.MAX_VALUE) // Use this to limit the number of samples.build();fashionMnist.prepare();Model model = Model.newInstance("fashion-mnist-model");TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()).addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)) {trainer.initialize(new long[]{1, 28, 28}); // Example shape for image dataMetrics metrics = new Metrics();trainer.setMetrics(metrics);for (Batch batch : trainer.iterateDataset(fashionMnist)) {EasyTrain.trainBatch(trainer, batch);trainer.step();batch.close();}trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));}}
      }

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DJLEngineExample {@Throws(ModelException::class, TranslateException::class, IOException::class)@JvmStaticfun main(args: Array<String>) {// Initialize the modelval model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine// Load a pre-trained modelmodel.load(Paths.get("path/to/your/model")) // Ensure the path is correct// Define a translator for data preprocessing and postprocessingval translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {val manager = ctx.ndManagerval array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessaryreturn NDList(array)}override fun processOutput(ctx: TranslatorContext, list: NDList): Float {// Assuming the output is a single scalar valuereturn list[0].getFloat() // Use getFloat() to get the scalar value}override fun getBatchifier(): Batchifier? {return null // Or implement batching if needed}}model.newPredictor(translator).use { predictor ->val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shapeval output = predictor.predict(input)println("Prediction: $output")}}
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManagerobject NDManagerExample {@JvmStaticfun main(args: Array<String>) {NDManager.newBaseManager().use { manager ->val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))println("Array: $array")// Perform operationsval result = array.add(2.0f)println("Result: $result")}// No need to explicitly free the memory, it's handled by the NDManager}
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Pathsobject DjlTrainerDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load datasetval trainDataset: RandomAccessDataset =FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()trainDataset.prepare(ProgressBar())// Define modelval model = Model.newInstance("mlp")model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))// Define training configurationval config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build()).build()).addTrainingListeners(LoggingTrainingListener())model.newTrainer(config).use { trainer ->trainer.initialize(Shape(1, (28 * 28).toLong()))for (epoch in 0..9) {for (batch in trainer.iterateDataset(trainDataset)) {trainer.step()batch.close()}trainer.notifyListeners { listener: TrainingListener ->listener.onEpoch(trainer)}}model.save(Paths.get("model"), "mlp")}}
    }
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DjlPredictorDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load modelval model = Model.newInstance("mlp")model.load(Paths.get("model"), "mlp")// Define Translatorval translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {return NDList(input.reshape(Shape(1, (28 * 28).toLong())))}override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {// Assuming the output NDArray is the first element in NDListval probabilities = list.singletonOrThrow()return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels}override fun getBatchifier(): Batchifier {return Batchifier.STACK}}model.newPredictor(translator).use { predictor ->val manager = NDManager.newBaseManager()val array = manager.ones(Shape(1, (28 * 28).toLong()))val classifications = predictor.predict(array)println(classifications)}}
    }

这篇关于java(kotlin) ai框架djl的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056911

相关文章

springboot将lib和jar分离的操作方法

《springboot将lib和jar分离的操作方法》本文介绍了如何通过优化pom.xml配置来减小SpringBoot项目的jar包大小,主要通过使用spring-boot-maven-plugin... 遇到一个问题,就是每次maven package或者maven install后target中的ja

Java中八大包装类举例详解(通俗易懂)

《Java中八大包装类举例详解(通俗易懂)》:本文主要介绍Java中的包装类,包括它们的作用、特点、用途以及如何进行装箱和拆箱,包装类还提供了许多实用方法,如转换、获取基本类型值、比较和类型检测,... 目录一、包装类(Wrapper Class)1、简要介绍2、包装类特点3、包装类用途二、装箱和拆箱1、装

如何利用Java获取当天的开始和结束时间

《如何利用Java获取当天的开始和结束时间》:本文主要介绍如何使用Java8的LocalDate和LocalDateTime类获取指定日期的开始和结束时间,展示了如何通过这些类进行日期和时间的处... 目录前言1. Java日期时间API概述2. 获取当天的开始和结束时间代码解析运行结果3. 总结前言在J

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

最长公共子序列问题的深度分析与Java实现方式

《最长公共子序列问题的深度分析与Java实现方式》本文详细介绍了最长公共子序列(LCS)问题,包括其概念、暴力解法、动态规划解法,并提供了Java代码实现,暴力解法虽然简单,但在大数据处理中效率较低,... 目录最长公共子序列问题概述问题理解与示例分析暴力解法思路与示例代码动态规划解法DP 表的构建与意义动

Java多线程父线程向子线程传值问题及解决

《Java多线程父线程向子线程传值问题及解决》文章总结了5种解决父子之间数据传递困扰的解决方案,包括ThreadLocal+TaskDecorator、UserUtils、CustomTaskDeco... 目录1 背景2 ThreadLocal+TaskDecorator3 RequestContextH

关于Spring @Bean 相同加载顺序不同结果不同的问题记录

《关于Spring@Bean相同加载顺序不同结果不同的问题记录》本文主要探讨了在Spring5.1.3.RELEASE版本下,当有两个全注解类定义相同类型的Bean时,由于加载顺序不同,最终生成的... 目录问题说明测试输出1测试输出2@Bean注解的BeanDefiChina编程nition加入时机总结问题说明

java父子线程之间实现共享传递数据

《java父子线程之间实现共享传递数据》本文介绍了Java中父子线程间共享传递数据的几种方法,包括ThreadLocal变量、并发集合和内存队列或消息队列,并提醒注意并发安全问题... 目录通过 ThreadLocal 变量共享数据通过并发集合共享数据通过内存队列或消息队列共享数据注意并发安全问题总结在 J

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首