java(kotlin) ai框架djl

2024-06-13 09:28
文章标签 java ai 框架 kotlin djl

本文主要是介绍java(kotlin) ai框架djl,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl--><dependency><groupId>ai.djl</groupId><artifactId>api</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-model-zoo</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>basicdataset</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>model-zoo</artifactId><version>0.28.0</version></dependency><!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘│┌─────────▼─────────┐│       Engine      │└───────┬─┬─────────┘│ │┌───────▼─▼─────────┐│     NDManager     │└───────┬─┬─────────┘│ │┌─────────▼─▼───────────┐│    Dataset └─────────┬─────────────┘│┌─────────▼─────────────┐│  Trainer / Predictor  │└───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;import java.io.IOException;
      import java.nio.file.Paths;public class DjlExample {public static void main(String[] args) throws IOException, ModelException, TranslateException {// 加载模型Criteria<Image, Classifications> criteria = Criteria.builder().optEngine("TensorFlow") // 选择引擎.setTypes(Image.class, Classifications.class).optModelPath(Paths.get("path/to/model")).build();try (Model model = ModelZoo.loadModel(criteria);Predictor<Image, Classifications> predictor = model.newPredictor()) {// 加载图像Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));// 进行推理Classifications result = predictor.predict(img);System.out.println(result);}}
      }
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;import java.io.IOException;public class DJLDatasetExample {public static void main(String[] args) throws IOException, TranslateException {NDManager manager = NDManager.newBaseManager();FashionMnist fashionMnist = FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true) // 32 is the batch size.optLimit(Long.MAX_VALUE) // Use this to limit the number of samples.build();fashionMnist.prepare();Model model = Model.newInstance("fashion-mnist-model");TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()).addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)) {trainer.initialize(new long[]{1, 28, 28}); // Example shape for image dataMetrics metrics = new Metrics();trainer.setMetrics(metrics);for (Batch batch : trainer.iterateDataset(fashionMnist)) {EasyTrain.trainBatch(trainer, batch);trainer.step();batch.close();}trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));}}
      }

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DJLEngineExample {@Throws(ModelException::class, TranslateException::class, IOException::class)@JvmStaticfun main(args: Array<String>) {// Initialize the modelval model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine// Load a pre-trained modelmodel.load(Paths.get("path/to/your/model")) // Ensure the path is correct// Define a translator for data preprocessing and postprocessingval translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {val manager = ctx.ndManagerval array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessaryreturn NDList(array)}override fun processOutput(ctx: TranslatorContext, list: NDList): Float {// Assuming the output is a single scalar valuereturn list[0].getFloat() // Use getFloat() to get the scalar value}override fun getBatchifier(): Batchifier? {return null // Or implement batching if needed}}model.newPredictor(translator).use { predictor ->val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shapeval output = predictor.predict(input)println("Prediction: $output")}}
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManagerobject NDManagerExample {@JvmStaticfun main(args: Array<String>) {NDManager.newBaseManager().use { manager ->val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))println("Array: $array")// Perform operationsval result = array.add(2.0f)println("Result: $result")}// No need to explicitly free the memory, it's handled by the NDManager}
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Pathsobject DjlTrainerDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load datasetval trainDataset: RandomAccessDataset =FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()trainDataset.prepare(ProgressBar())// Define modelval model = Model.newInstance("mlp")model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))// Define training configurationval config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build()).build()).addTrainingListeners(LoggingTrainingListener())model.newTrainer(config).use { trainer ->trainer.initialize(Shape(1, (28 * 28).toLong()))for (epoch in 0..9) {for (batch in trainer.iterateDataset(trainDataset)) {trainer.step()batch.close()}trainer.notifyListeners { listener: TrainingListener ->listener.onEpoch(trainer)}}model.save(Paths.get("model"), "mlp")}}
    }
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DjlPredictorDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load modelval model = Model.newInstance("mlp")model.load(Paths.get("model"), "mlp")// Define Translatorval translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {return NDList(input.reshape(Shape(1, (28 * 28).toLong())))}override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {// Assuming the output NDArray is the first element in NDListval probabilities = list.singletonOrThrow()return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels}override fun getBatchifier(): Batchifier {return Batchifier.STACK}}model.newPredictor(translator).use { predictor ->val manager = NDManager.newBaseManager()val array = manager.ones(Shape(1, (28 * 28).toLong()))val classifications = predictor.predict(array)println(classifications)}}
    }

这篇关于java(kotlin) ai框架djl的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056911

相关文章

springboot集成easypoi导出word换行处理过程

《springboot集成easypoi导出word换行处理过程》SpringBoot集成Easypoi导出Word时,换行符n失效显示为空格,解决方法包括生成段落或替换模板中n为回车,同时需确... 目录项目场景问题描述解决方案第一种:生成段落的方式第二种:替换模板的情况,换行符替换成回车总结项目场景s

SpringBoot集成redisson实现延时队列教程

《SpringBoot集成redisson实现延时队列教程》文章介绍了使用Redisson实现延迟队列的完整步骤,包括依赖导入、Redis配置、工具类封装、业务枚举定义、执行器实现、Bean创建、消费... 目录1、先给项目导入Redisson依赖2、配置redis3、创建 RedissonConfig 配

SpringBoot中@Value注入静态变量方式

《SpringBoot中@Value注入静态变量方式》SpringBoot中静态变量无法直接用@Value注入,需通过setter方法,@Value(${})从属性文件获取值,@Value(#{})用... 目录项目场景解决方案注解说明1、@Value("${}")使用示例2、@Value("#{}"php

SpringBoot分段处理List集合多线程批量插入数据方式

《SpringBoot分段处理List集合多线程批量插入数据方式》文章介绍如何处理大数据量List批量插入数据库的优化方案:通过拆分List并分配独立线程处理,结合Spring线程池与异步方法提升效率... 目录项目场景解决方案1.实体类2.Mapper3.spring容器注入线程池bejsan对象4.创建

线上Java OOM问题定位与解决方案超详细解析

《线上JavaOOM问题定位与解决方案超详细解析》OOM是JVM抛出的错误,表示内存分配失败,:本文主要介绍线上JavaOOM问题定位与解决方案的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一、OOM问题核心认知1.1 OOM定义与技术定位1.2 OOM常见类型及技术特征二、OOM问题定位工具

基于 Cursor 开发 Spring Boot 项目详细攻略

《基于Cursor开发SpringBoot项目详细攻略》Cursor是集成GPT4、Claude3.5等LLM的VSCode类AI编程工具,支持SpringBoot项目开发全流程,涵盖环境配... 目录cursor是什么?基于 Cursor 开发 Spring Boot 项目完整指南1. 环境准备2. 创建

Spring Security简介、使用与最佳实践

《SpringSecurity简介、使用与最佳实践》SpringSecurity是一个能够为基于Spring的企业应用系统提供声明式的安全访问控制解决方案的安全框架,本文给大家介绍SpringSec... 目录一、如何理解 Spring Security?—— 核心思想二、如何在 Java 项目中使用?——

SpringBoot+RustFS 实现文件切片极速上传的实例代码

《SpringBoot+RustFS实现文件切片极速上传的实例代码》本文介绍利用SpringBoot和RustFS构建高性能文件切片上传系统,实现大文件秒传、断点续传和分片上传等功能,具有一定的参考... 目录一、为什么选择 RustFS + SpringBoot?二、环境准备与部署2.1 安装 RustF

springboot中使用okhttp3的小结

《springboot中使用okhttp3的小结》OkHttp3是一个JavaHTTP客户端,可以处理各种请求类型,比如GET、POST、PUT等,并且支持高效的HTTP连接池、请求和响应缓存、以及异... 在 Spring Boot 项目中使用 OkHttp3 进行 HTTP 请求是一个高效且流行的方式。

java.sql.SQLTransientConnectionException连接超时异常原因及解决方案

《java.sql.SQLTransientConnectionException连接超时异常原因及解决方案》:本文主要介绍java.sql.SQLTransientConnectionExcep... 目录一、引言二、异常信息分析三、可能的原因3.1 连接池配置不合理3.2 数据库负载过高3.3 连接泄漏