java(kotlin) ai框架djl

2024-06-13 09:28
文章标签 java ai 框架 kotlin djl

本文主要是介绍java(kotlin) ai框架djl,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl--><dependency><groupId>ai.djl</groupId><artifactId>api</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-model-zoo</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>basicdataset</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>model-zoo</artifactId><version>0.28.0</version></dependency><!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘│┌─────────▼─────────┐│       Engine      │└───────┬─┬─────────┘│ │┌───────▼─▼─────────┐│     NDManager     │└───────┬─┬─────────┘│ │┌─────────▼─▼───────────┐│    Dataset └─────────┬─────────────┘│┌─────────▼─────────────┐│  Trainer / Predictor  │└───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;import java.io.IOException;
      import java.nio.file.Paths;public class DjlExample {public static void main(String[] args) throws IOException, ModelException, TranslateException {// 加载模型Criteria<Image, Classifications> criteria = Criteria.builder().optEngine("TensorFlow") // 选择引擎.setTypes(Image.class, Classifications.class).optModelPath(Paths.get("path/to/model")).build();try (Model model = ModelZoo.loadModel(criteria);Predictor<Image, Classifications> predictor = model.newPredictor()) {// 加载图像Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));// 进行推理Classifications result = predictor.predict(img);System.out.println(result);}}
      }
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;import java.io.IOException;public class DJLDatasetExample {public static void main(String[] args) throws IOException, TranslateException {NDManager manager = NDManager.newBaseManager();FashionMnist fashionMnist = FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true) // 32 is the batch size.optLimit(Long.MAX_VALUE) // Use this to limit the number of samples.build();fashionMnist.prepare();Model model = Model.newInstance("fashion-mnist-model");TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()).addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)) {trainer.initialize(new long[]{1, 28, 28}); // Example shape for image dataMetrics metrics = new Metrics();trainer.setMetrics(metrics);for (Batch batch : trainer.iterateDataset(fashionMnist)) {EasyTrain.trainBatch(trainer, batch);trainer.step();batch.close();}trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));}}
      }

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DJLEngineExample {@Throws(ModelException::class, TranslateException::class, IOException::class)@JvmStaticfun main(args: Array<String>) {// Initialize the modelval model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine// Load a pre-trained modelmodel.load(Paths.get("path/to/your/model")) // Ensure the path is correct// Define a translator for data preprocessing and postprocessingval translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {val manager = ctx.ndManagerval array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessaryreturn NDList(array)}override fun processOutput(ctx: TranslatorContext, list: NDList): Float {// Assuming the output is a single scalar valuereturn list[0].getFloat() // Use getFloat() to get the scalar value}override fun getBatchifier(): Batchifier? {return null // Or implement batching if needed}}model.newPredictor(translator).use { predictor ->val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shapeval output = predictor.predict(input)println("Prediction: $output")}}
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManagerobject NDManagerExample {@JvmStaticfun main(args: Array<String>) {NDManager.newBaseManager().use { manager ->val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))println("Array: $array")// Perform operationsval result = array.add(2.0f)println("Result: $result")}// No need to explicitly free the memory, it's handled by the NDManager}
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Pathsobject DjlTrainerDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load datasetval trainDataset: RandomAccessDataset =FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()trainDataset.prepare(ProgressBar())// Define modelval model = Model.newInstance("mlp")model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))// Define training configurationval config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build()).build()).addTrainingListeners(LoggingTrainingListener())model.newTrainer(config).use { trainer ->trainer.initialize(Shape(1, (28 * 28).toLong()))for (epoch in 0..9) {for (batch in trainer.iterateDataset(trainDataset)) {trainer.step()batch.close()}trainer.notifyListeners { listener: TrainingListener ->listener.onEpoch(trainer)}}model.save(Paths.get("model"), "mlp")}}
    }
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DjlPredictorDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load modelval model = Model.newInstance("mlp")model.load(Paths.get("model"), "mlp")// Define Translatorval translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {return NDList(input.reshape(Shape(1, (28 * 28).toLong())))}override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {// Assuming the output NDArray is the first element in NDListval probabilities = list.singletonOrThrow()return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels}override fun getBatchifier(): Batchifier {return Batchifier.STACK}}model.newPredictor(translator).use { predictor ->val manager = NDManager.newBaseManager()val array = manager.ones(Shape(1, (28 * 28).toLong()))val classifications = predictor.predict(array)println(classifications)}}
    }

这篇关于java(kotlin) ai框架djl的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056911

相关文章

SpringBoot中SM2公钥加密、私钥解密的实现示例详解

《SpringBoot中SM2公钥加密、私钥解密的实现示例详解》本文介绍了如何在SpringBoot项目中实现SM2公钥加密和私钥解密的功能,通过使用Hutool库和BouncyCastle依赖,简化... 目录一、前言1、加密信息(示例)2、加密结果(示例)二、实现代码1、yml文件配置2、创建SM2工具

Spring WebFlux 与 WebClient 使用指南及最佳实践

《SpringWebFlux与WebClient使用指南及最佳实践》WebClient是SpringWebFlux模块提供的非阻塞、响应式HTTP客户端,基于ProjectReactor实现,... 目录Spring WebFlux 与 WebClient 使用指南1. WebClient 概述2. 核心依

Spring Boot @RestControllerAdvice全局异常处理最佳实践

《SpringBoot@RestControllerAdvice全局异常处理最佳实践》本文详解SpringBoot中通过@RestControllerAdvice实现全局异常处理,强调代码复用、统... 目录前言一、为什么要使用全局异常处理?二、核心注解解析1. @RestControllerAdvice2

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

Spring事务传播机制最佳实践

《Spring事务传播机制最佳实践》Spring的事务传播机制为我们提供了优雅的解决方案,本文将带您深入理解这一机制,掌握不同场景下的最佳实践,感兴趣的朋友一起看看吧... 目录1. 什么是事务传播行为2. Spring支持的七种事务传播行为2.1 REQUIRED(默认)2.2 SUPPORTS2

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.

Java进程异常故障定位及排查过程

《Java进程异常故障定位及排查过程》:本文主要介绍Java进程异常故障定位及排查过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、故障发现与初步判断1. 监控系统告警2. 日志初步分析二、核心排查工具与步骤1. 进程状态检查2. CPU 飙升问题3. 内存

java中新生代和老生代的关系说明

《java中新生代和老生代的关系说明》:本文主要介绍java中新生代和老生代的关系说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、内存区域划分新生代老年代二、对象生命周期与晋升流程三、新生代与老年代的协作机制1. 跨代引用处理2. 动态年龄判定3. 空间分

Java设计模式---迭代器模式(Iterator)解读

《Java设计模式---迭代器模式(Iterator)解读》:本文主要介绍Java设计模式---迭代器模式(Iterator),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,... 目录1、迭代器(Iterator)1.1、结构1.2、常用方法1.3、本质1、解耦集合与遍历逻辑2、统一

Java内存分配与JVM参数详解(推荐)

《Java内存分配与JVM参数详解(推荐)》本文详解JVM内存结构与参数调整,涵盖堆分代、元空间、GC选择及优化策略,帮助开发者提升性能、避免内存泄漏,本文给大家介绍Java内存分配与JVM参数详解,... 目录引言JVM内存结构JVM参数概述堆内存分配年轻代与老年代调整堆内存大小调整年轻代与老年代比例元空