金沙js0888_首頁(欢迎您)

  • <td id="dgejl"><strike id="dgejl"></strike></td>
        1. DJL 如何正确打开 [ 深度学习 ]

          作者: 小明菜市场 2020-11-04 07:10:42

           

          本文转载自微信公众号「小明菜市场」,作者小明菜市场。转载本文请联系小明菜市场公众号。

          前言

          很长时间,Java都是一个相当受欢迎的企业编程语言,其框架丰富,生态完善。Java拥有庞大的开发者社区,尽管深度学习应用不断推进和演化,但是相关的深度学习框架对于Java来说相当的稀少,现如今,主要模型都是Python编译和训练,对于Java开发者来说,如果想要学习深度学习,就需要接受一门新的语言的洗礼。为了减少Java开发者学习深度学习的成本,AWS构建了一个Deep Java Library(DJL),一个为Java开发者定制的开源深度学习框架,其为开发者对接主流深度学习框架,提供了一个接口。

          什么是深度学习

          在开始之前,先了解机器学习和深度学习基础概念。机器学习是一个利用统计学知识,把数据输入到计算机中进行训练并完成特定目标任务的过程,这种归纳学习方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体。深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发,人工神经网络是通过研究人脑如何学习和实现目标的过程中,归纳出的一套计算逻辑。通过模拟部分人脑神经间信息传递的过程,从而实现各种复杂的任务,深度学习中的深度来源于会在人工神经网络中编制出,构建出许多层,从而进一步对数据信息进行更为深层次的传导。

          训练 MNIST 手写数字识别

          项目配置

          利用 gradle 配置引入依赖包,用DJL的api包和basicdataset包来构建神经网络和数据集,这个案例,使用 MXNet作为深度学习引擎,所以引入mxnet-engine和mxnet-native-auto两个包,依赖如下

          1. plugins { 
          2.     id 'java' 
          3. repositories {                            
          4.     jcenter() 
          5. dependencies { 
          6.     implementation platform("ai.djl:bom:0.8.0"
          7.     implementation "ai.djl:api" 
          8.     implementation "ai.djl:basicdataset" 
          9.     // MXNet 
          10.     runtimeOnly "ai.djl.mxnet:mxnet-engine" 
          11.     runtimeOnly "ai.djl.mxnet:mxnet-native-auto" 

          NDArry 和 NDManager

          NDArray 是 DJL 存储数据结构和数学运算的基本结构,一个NDArry表达了一个定长的多维数组,NDArry的使用方法,类似于Python的numpy.ndarry。NDManager是NDArry的管理者,其负责管理NDArry的产生和回收过程,这样可以帮助我们更好的对Java内存进行优化,每一个NDArry都会由一个NDManager创造出来,同时他们会在NDManager关闭时一同关闭,

          Model

          在 DJL 中,训练和推理都是从 Model class 开始构建的,我们在这里主要训练过程中的构建方法,下面我们为 Model 创建一个新的目标,因为 Model 也是继承了 AutoClosable 结构体,用一个 try block实现。

          1. try (Model model = Model.newInstance()) { 
          2.     ... 
          3.     // 主体训练代码 
          4.     ... 

          准备数据

          MNIST 数据库包含大量的手写数字的图,通常用来训练图像处理系统,DJL已经把MNIST的数据收集到了 basicdataset 数据里,每个 MNIST 的图的大小是 28 * 28, 如果有自己的数据集,同样可以使用同理来收集数据。

          数据集导入教程 https://docs.djl.ai/docs/development/how_to_use_dataset.html#how-to-create-your-own-dataset

          1. int batchSize = 32; // 批大小 
          2. Mnist trainingDataset = Mnist.builder() 
          3.         .optUsage(Usage.TRAIN) // 训练集 
          4.         .setSampling(batchSize, true
          5.         .build(); 
          6. Mnist validationDataset = Mnist.builder() 
          7.         .optUsage(Usage.TEST) // 验证集 
          8.         .setSampling(batchSize, true
          9.         .build(); 

          这段代码分别制作了训练和验证集,同时我们也随机的排列了数据集从而更好的训练,除了这些配置以外,也可以对图片进行进一步的设置,例如设置图片大小,归一化处理。

          制作 model 建立 block

          当数据集准备就绪以后,就可以构建神经网络,在DJL 中,神经网络是由 Block 代码块构成的,一个Block是一个具备多种神经网络特性的结构,他们可以代表一个操作神经网络的一部分,甚至一个完整的神经网络,然后 block 就可以顺序的执行或者并行。同时 block 本身也可以带参数和子block,这种嵌套结构可以快速的帮助更新一个可维护的神经网络,在训练过程中,每个block附带参数也会实时更新,同时也会更新其子 block。当我们构建这些 block 的过程中,最简单的方式就是把他们一个一个嵌套起来,直接使用准备好的 DJL的 Block 种类,我们就可以快速制作各种神经网络。

          block 变体

          根据几种基本的神经网络工作模式,我们提供几种Block的变体,

          1. SequentialBlock 是为了输出作为下一个block的输入继续执行到底。
          2. parallelblock 是用于将一个输入并行输入到每一个子block中,同时也将输出结果根据特定的合并方程合并起来。
          3. lambdablock 是帮助用户进行快速操作的一个block,其中不具备任何参数,所以在训练的过程中没有任何部分在训练过程中更新。

          构建多层感知机 MLP 神经网络

          我们构建一个简单的多层感知机神经网络,多层感知机是一个简单的前向型神经网络,只包含几个全连接层,构建这个网路可以直接使用 sequentialblock

          1. int input = 28 * 28; // 输入层大小 
          2. int output = 10; // 输出层大小 
          3. int[] hidden = new int[] {128, 64}; // 隐藏层大小 
          4. SequentialBlock sequentialBlock = new SequentialBlock(); 
          5. sequentialBlock.add(Blocks.batchFlattenBlock(input)); 
          6. for (int hiddenSize : hidden) { 
          7.     // 全连接层 
          8.     sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build()); 
          9.     // 激活函数 
          10.     sequentialBlock.add(activation); 
          11. sequentialBlock.add(Linear.builder().setUnits(output).build()); 

          可以使用直接提供好的 MLP Block

          1. Block block = new Mlp( 
          2.         Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, 
          3.         Mnist.NUM_CLASSES, 
          4.         new int[] {128, 64}); 

          训练

          使用如下几个步骤,

          完成一个训练过程初始化:我们会对每一个Block的参数进行初始化,初始化每个参数的函数都是由设定的 initializer决定的。前向传播:这一步把输入数据在神经网络中逐层传递,然后产生输出数据。计算损失:我们会根据特定的损失函数 loss 来计算输出和标记结果的偏差。反向传播:在这一步中,利用损失反向求导计算出每一个参数的梯度。更新权重,会根据选择的优化器,更新每一个在 Block 上的参数的值。

          精简

          DJL 利用了 Trainer 结构体精简了整个过程,开发者只需要创建Trainer 并指定对应的initializer,loss,optimizer即可,这些参数都是由TrainingConfig设定,来看参数的设置。TrainingListener 训练过程设定的监听器,可以实时反馈每个阶段的训练结果,这些结果可以用于记录训练过程或者帮助 debug 神经网络训练过程中遇到的问题。用户可以定制自己的 TrainingListener 来训练过程进行监听

          1. DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) 
          2.     .addEvaluator(new Accuracy()) 
          3.     .addTrainingListeners(TrainingListener.Defaults.logging()); 
          4. try (Trainer trainer = model.newTrainer(config)){ 
          5.     // 训练代码 

          训练产生以后,可以定义输入的 Shape,之后可以调用 git函数进行训练,结果会保存在本地目录下

          1. /* 
          2.  * MNIST 包含 28x28 灰度图片并导入成 28 * 28 NDArray。 
          3.  * 第一个维度是批大小, 在这里我们设置批大小为 1 用于初始化。 
          4.  */ 
          5. Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH); 
          6. int numEpoch = 5; 
          7. String outputDir = "/build/model"
          8.  
          9. // 用输入初始化 trainer 
          10. trainer.initialize(inputShape); 
          11.  
          12. TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp"); 

          输出的结果图

          1. [INFO ] - Downloading libmxnet.dylib ... 
          2. [INFO ] - Training on: cpu(). 
          3. [INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms. 
          4. Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec 
          5. Validating:  100% |████████████████████████████████████████| 
          6. [INFO ] - Epoch 1 finished. 
          7. [INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24 
          8. [INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14 
          9. Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec 
          10. Validating:  100% |████████████████████████████████████████| 
          11. [INFO ] - Epoch 2 finished.NG [1m 41s] 
          12. [INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10 
          13. [INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09 
          14. [INFO ] - train P50: 12.756 ms, P90: 21.044 ms 
          15. [INFO ] - forward P50: 0.375 ms, P90: 0.607 ms 
          16. [INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms 
          17. [INFO ] - backward P50: 0.608 ms, P90: 0.973 ms 
          18. [INFO ] - step P50: 0.543 ms, P90: 0.869 ms 
          19. [INFO ] - epoch P50: 35.989 s, P90: 35.989 s 

          训练结束以后,就可以对模型进行识别了和使用了。

          关于作者

          我是小小,一个生于二线城市活在一线城市的小小,本期结束,我们下期再见。

          深度学习 DJL 语言
          上一篇:本周六锁定成都!解析百度文心(ERNIE)如何助力快速定制企业级NLP模型 下一篇:人在旅途不太囧,国内景区部署AI解锁智慧游玩新场景
          评论
          取消
          暂无评论,快去成为第一个评论的人吧

          更多资讯推荐

          人工智能的三个必要条件

          2016年,AlphaGo下围棋战胜李世乭,大家都认为人工智能的时代到来了。人工智能也是同样的在一定的历史契机下,几个独立发展的领域碰巧合并在一起就产生了巨大的推动力。这一波人工智能发展的三个必要条件是:深度学习模型,大数据,算力(并行计算)。

          麦教授说 ·? 22h前
          企业在应用人工智能时不可不知的5个误区

          研究表明,70%以上的企业如今将人工智能视为游戏规则的改变者。然而,目前使用人工智能或计划很快使用人工智能的企业不到40%。人们对人工智能重要性的认识差距仍然很大,以下是五个关于人工智能的常见误区或误解。

          Geertrui Mieke ·? 2天前
          Java为什么不能真正支持机器/深度学习?到底还欠缺了什么

          自1998年以来,就多个企业的变革而言,Java一直处于领先地位 - 网络,移动,浏览器与原生,消息传递,i18n和l10n全球化支持,扩展和支持各种企业信息存储值得一提的是,从关系数据库到Elasticsearch。

          佚名 ·? 3天前
          人工智能简史|(2)深度学习,掀起人工智能的新高潮

          刘韩 ·? 4天前
          人工智能简史|(1)深度学习,掀起人工智能的新高潮

          也许你觉得人工智能离你还有点远,只存在于谷歌那巨大无比的数据中心机房,或者充满神秘感的麻省理工学院机器人实验室。其实,透过互联网和智能手机,人工智能已经开始渗入我们每天的日常生活。

          刘韩 ·? 4天前
          基于深度学习的图像匹配技术一览

          佚名 ·? 4天前
          深度学习:突破新兴技术的边界

          从大数据到AI,几乎每个正在发展的技术分支都受益于深度学习的深刻价值。

          千家网 ·? 2021-04-01 13:53:26
          高空抛物悲剧频出,AI 监控系统:让我来「罩」着你

          高空抛物是现代社会的一大顽疾,一直以来都缺乏有效监控手段。安防公司借助人工智能技术,利用视频与图像分析,给出了监管高空抛物的解决方案。

          佚名 ·? 2021-03-30 22:41:30
          Copyright?2005-2021 51CTO.COM 版权所有 未经许可 请勿转载
          金沙js0888