1
- # 用PyTorch和DJL做涂鸦识别的小游戏
1
+ # 用 Java 做个“你画手机猜”的小游戏
2
2
3
3
> 本文适合有 Java 基础的人群
4
4
5
- ![ ] ( https://raw.githubusercontent.com/aws-samples/djl-demo/master/android/img/rabbit .gif)
5
+ ![ ] ( 0 .gif)
6
6
7
7
作者:** DJL-Lanking**
8
8
9
9
HelloGitHub 推出的[ 《讲解开源项目》] ( https://github.com/HelloGitHub-Team/Article ) 系列。有幸邀请到了亚马逊 + Apache 的工程师:Lanking( https://github.com/lanking520 ),为我们讲解 DJL —— 完全由 Java 构建的深度学习平台,本文为系列的第三篇。
10
10
11
11
## 一、前言
12
12
13
- 在2018年时 ,Google 推出了《猜画小歌》应用:玩家可以直接与AI进行你画我猜的游戏 。通过画出一个房子或者一个猫,AI 会推断出各种物品被画出的概率。它的实现得益于深度学习模型在其中的应用,通过深度神经网络的归纳,曾经令人头疼的绘画识别也变得易如反掌。现如今,只要使用一个简单的图片分类模型,我们便可以轻松的实现绘画识别。试试看这个在线涂鸦小游戏吧:
13
+ 在 2018 年时 ,Google 推出了《猜画小歌》应用:玩家可以直接与 AI 进行你画我猜的游戏 。通过画出一个房子或者一个猫,AI 会推断出各种物品被画出的概率。它的实现得益于深度学习模型在其中的应用,通过深度神经网络的归纳,曾经令人头疼的绘画识别也变得易如反掌。现如今,只要使用一个简单的图片分类模型,我们便可以轻松的实现绘画识别。试试看这个在线涂鸦小游戏吧:
14
14
15
15
> 在线涂鸦小游戏:https://djl.ai/website/demo.html#doodle
16
16
17
- 在当时,大部分机器学习计算任务仍旧需要依托网络在云端进行。随着算力的不断增进,机器学习任务已经可以直接在边缘设备部署,包括各类运行安卓系统的智能手机。但是,由于安卓本身主要是用 Java ,部署基于 Python 的各类深度学习模型变成了一个难题。为了解决这个问题,AWS开发并开源了 DeepJavaLibrary (DJL),一个为 Java 量身定制的深度学习框架。
17
+ 在当时,大部分机器学习计算任务仍旧需要依托网络在云端进行。随着算力的不断增进,机器学习任务已经可以直接在边缘设备部署,包括各类运行安卓系统的智能手机。但是,由于安卓本身主要是用 Java ,部署基于 Python 的各类深度学习模型变成了一个难题。为了解决这个问题,AWS 开发并开源了 DeepJavaLibrary (DJL),一个为 Java 量身定制的深度学习框架。
18
18
19
- 在这个文章中,我们将尝试通过PyTorch预训练模型在在安卓平台构建一个涂鸦绘画的应用 。由于总代码量会比较多,我们这次会挑重点把最关键的代码完成。你可以后续参考我们完整的项目进行构建。
19
+ 在这个文章中,我们将尝试通过 PyTorch 预训练模型在安卓平台构建一个涂鸦绘画的应用 。由于总代码量会比较多,我们这次会挑重点把最关键的代码完成。你可以后续参考我们完整的项目进行构建。
20
20
21
21
> 涂鸦应用完整代码:https://github.com/aws-samples/djl-demo/tree/master/android
22
22
@@ -41,27 +41,27 @@ dependencies {
41
41
42
42
## 三、构建应用
43
43
44
- ### 第一步:创建Layout
44
+ ### 3.1 第一步:创建 Layout
45
45
46
- 我们可以先创建一个 View class 以及 layout (如下图)来构建安卓的前端显示界面。
46
+ 我们可以先创建一个 View class 以及 layout(如下图)来构建安卓的前端显示界面。
47
47
48
48
![ ] ( 1.png )
49
49
50
- 如上图所示,你可以在主界面创建两个 ` View ` 目标。 ` PaintView ` 是用来让用户画画的,在右下角 ` ImageView ` 是用来展示用于深度学习推理的图像。同时我们预留一个按钮来进行画板的清空操作。
50
+ 如上图所示,你可以在主界面创建两个 ` View ` 目标。` PaintView ` 是用来让用户画画的,在右下角 ` ImageView ` 是用来展示用于深度学习推理的图像。同时我们预留一个按钮来进行画板的清空操作。
51
51
52
- ### 第二步: 应对绘画动作
52
+ ### 3.2 第二步: 应对绘画动作
53
53
54
54
在安卓设备上,你可以自定义安卓的触摸事件响应来应对用户的各种触控操作。在我们的情况下,我们需要定义下面三种时间响应:
55
55
56
- * touchStart: 感应触碰时触发
57
- * touchMove: 当用户在屏幕上移动手指时触发
58
- * touchUp: 当用户抬起手指时触发
56
+ * touchStart: 感应触碰时触发
57
+ * touchMove: 当用户在屏幕上移动手指时触发
58
+ * touchUp: 当用户抬起手指时触发
59
59
60
60
与此同时,我们用 paths 来存储用户在画板所绘制的路径。现在我们看一下实现代码。
61
61
62
- #### 重写 ` OnTouchEvent ` 和 ` OnDraw ` 方法
62
+ #### 3.2.1 重写 ` OnTouchEvent ` 和 ` OnDraw ` 方法
63
63
64
- 现在我们重写 ` onTouchEvent ` 来应对各种响应:
64
+ 现在我们重写 ` onTouchEvent ` 来应对各种响应:
65
65
66
66
``` java
67
67
@Override
@@ -91,7 +91,7 @@ public boolean onTouchEvent(MotionEvent event) {
91
91
92
92
如上面代码所示,你可以添加一个 ` runInference ` 方法在 ` MotionEvent.ACTION_UP ` 事件响应上。这个方法是用来在用户绘制完后对结果进行推理。在之后的几步中,我们会讲解它的具体实现。
93
93
94
- 我们同样需要重写 ` onDraw ` 方法来展示用户绘制的图像:
94
+ 我们同样需要重写 ` onDraw ` 方法来展示用户绘制的图像:
95
95
96
96
``` java
97
97
@Override
@@ -111,7 +111,7 @@ protected void onDraw(Canvas canvas) {
111
111
112
112
真正的图像会保存在一个 ` Bitmap ` 上。
113
113
114
- #### touchStart
114
+ #### 3.2.2 操作开始( touchStart)
115
115
116
116
当用户触碰行为开始时,下面的代码会建立一个新的路径同时记录路径中每一个点在屏幕上的坐标。
117
117
@@ -126,11 +126,11 @@ private void touchStart(float x, float y) {
126
126
}
127
127
```
128
128
129
- #### touchMove
129
+ #### 3.2.3 手指移动( touchMove)
130
130
131
131
在手指移动中,我们会持续记录坐标点然后将它们构成一个 quadratic bezier. 通过一定的误差阀值来动态优化用户的绘画动作。只有差别超出误差范围内的动作才会被记录下来。
132
132
133
- > quadratic bezier 文档: https://developer.android.com/reference/android/graphics/Path#quadTo(float,%20float,%20float,%20float)
133
+ > quadratic bezier 文档: https://developer.android.com/reference/android/graphics/Path
134
134
135
135
``` java
136
136
private void touchMove(float x, float y) {
@@ -148,7 +148,7 @@ private void touchMove(float x, float y) {
148
148
}
149
149
```
150
150
151
- #### touchUp
151
+ #### 3.2.4 操作结束( touchUp)
152
152
153
153
当触控操作结束后,下面的代码会绘制一个路径同时计算最小长方形目标框。
154
154
@@ -159,31 +159,31 @@ private void touchUp() {
159
159
}
160
160
```
161
161
162
- ### 第三步: 开始推理
162
+ ### 3.3 第三步: 开始推理
163
163
164
164
为了在安卓设备上进行推理任务,我们需要完成下面几个任务:
165
165
166
- * 从URL读取模型
166
+ * 从 URL 读取模型
167
167
* 构建前处理和后处理过程
168
168
* 从 PaintView 进行推理任务
169
169
170
- 为了完成以下目标,我们尝试构建一个 ` DoodleModel ` class 。在这一步,我们将介绍一些完成这些任务的关键步骤。
170
+ 为了完成以下目标,我们尝试构建一个 ` DoodleModel ` class。在这一步,我们将介绍一些完成这些任务的关键步骤。
171
171
172
- #### 读取模型
172
+ #### 3.3.1 读取模型
173
173
174
- DJL内建了一套模型管理系统 。开发者可以自定义储存模型的文件夹。
174
+ DJL 内建了一套模型管理系统 。开发者可以自定义储存模型的文件夹。
175
175
176
176
``` java
177
177
File dir = getFilesDir();
178
178
System . setProperty(" DJL_CACHE_DIR" , dir. getAbsolutePath());
179
179
```
180
180
181
- 通过更改 ` DJL_CACHE_DIR ` 属性, 模型会被存入相应路径下.
181
+ 通过更改 ` DJL_CACHE_DIR ` 属性, 模型会被存入相应路径下。
182
182
183
- 下一步可以通过定义 Criteria 从指定 URL 处下载模型. 下载的 zip 文件内包含:
183
+ 下一步可以通过定义 Criteria 从指定 URL 处下载模型。 下载的 zip 文件内包含:
184
184
185
- * doodle_mobilenet.pt: PyTorch 模型
186
- * synset.txt: 包含分类任务中所有类别的名称
185
+ * ` doodle_mobilenet.pt ` : PyTorch 模型
186
+ * ` synset.txt ` : 包含分类任务中所有类别的名称
187
187
188
188
``` java
189
189
Criteria<Image , Classifications > criteria =
@@ -194,9 +194,9 @@ Criteria<Image, Classifications> criteria =
194
194
.build();
195
195
return ModelZoo . loadModel(criteria);
196
196
```
197
+ 上述代码同时定义了 translator,它会被用来做图片的前处理和后处理。
197
198
198
- 上述代码同时定义了 translator 。translator 会被用来做图片的前处理和后处理。
199
- 最后,如下述代码创建一个 ` Model ` 并用它创建一个 ` Predictor ` :
199
+ 最后,如下述代码创建一个 ` Model ` 并用它创建一个 ` Predictor ` :
200
200
201
201
``` java
202
202
@Override
@@ -212,13 +212,13 @@ protected Boolean doInBackground(Void... params) {
212
212
}
213
213
```
214
214
215
- 更多关于模型加载的信息,请参阅如何加载模型.
215
+ 更多关于模型加载的信息,请参阅如何加载模型。
216
216
217
- > DJL模型加载文档: http://docs.djl.ai/docs/load_model.html
217
+ > DJL 模型加载文档: http://docs.djl.ai/docs/load_model.html
218
218
219
- #### 用Translator定义前处理和后处理
219
+ #### 3.3.2 用 Translator 定义前处理和后处理
220
220
221
- 在DJL中, 我们定义了 Translator 接口进行前处理和后处理。在 DoodleModel 中我们定义了 ImageClassificationTranslator 来实现 Translator:
221
+ 在 DJL 中, 我们定义了 Translator 接口进行前处理和后处理。在 DoodleModel 中我们定义了 ImageClassificationTranslator 来实现 Translator:
222
222
223
223
``` java
224
224
ImageClassificationTranslator . builder()
@@ -227,12 +227,11 @@ ImageClassificationTranslator.builder()
227
227
.optApplySoftmax(true ). build());
228
228
```
229
229
230
- 下面我们详细阐述 translator 所定义的前处理和后处理如何被用在模型的推理步骤中。当你创建 translator 时,内部程序会自动加载 ` synset.txt ` 文件得到做分类任务时所有类别的名称。当模型的predict()方法被调用时,内部程序会先执行所对应的 translator 的前处理步骤,而后执行实际推理步骤,最后执行 translator 的后处理步骤。对于前处理,我们会将 Image转化NDArray,用于作为模型推理过程的输入。对于后处理,我们对推理输出的结果(NDArray)进行 softmax 操作。最终返回结果为 Classifications 的一个实例。
231
- 更多关于 translator 的工作原理以及如何个性化 Translator 的信息,请参阅 Inference with your model。
230
+ 下面我们详细阐述 translator 所定义的前处理和后处理如何被用在模型的推理步骤中。当你创建 translator 时,内部程序会自动加载 ` synset.txt ` 文件得到做分类任务时所有类别的名称。当模型的 ` predict() ` 方法被调用时,内部程序会先执行所对应的 translator 的前处理步骤,而后执行实际推理步骤,最后执行 translator 的后处理步骤。对于前处理,我们会将 Image 转化 NDArray,用于作为模型推理过程的输入。对于后处理,我们对推理输出的结果(NDArray)进行 softmax 操作。最终返回结果为 Classifications 的一个实例。
232
231
233
- > 自定义Translator案例: http://docs.djl.ai/jupyter/pytorch/load_your_own_pytorch_bert.html
232
+ > 自定义 Translator 案例: http://docs.djl.ai/jupyter/pytorch/load_your_own_pytorch_bert.html
234
233
235
- #### 用 PaintView 进行推理任务
234
+ #### 3.3.3 用 PaintView 进行推理任务
236
235
237
236
最后,我们来实现之前定义好的 runInference 方法。
238
237
@@ -262,13 +261,13 @@ public void runInference() {
262
261
263
262
恭喜你!我们完成了一个涂鸦识别小程序!
264
263
265
- ### 可选优化: 输入裁剪
264
+ ### 3.4 可选优化: 输入裁剪
266
265
267
266
为了得到更高的模型推理准确度,你可以通过截取图像来去除无意义的边框部分。
268
267
269
268
![ ] ( 3.png )
270
269
271
- 上面右侧的图片会比左边的图片有更好的推理结果,因为它所包含的空白边框更少。你可以通过 Bound 类来寻找图片的有效边界,即能把图中所有白色像素点覆盖的最小矩形。在得到x轴最左坐标,y轴最上坐标 ,以及矩形高度和宽度后,就可以用这些信息截取出我们想要的图形(如右图所示)实现代码如下:
270
+ 上面右侧的图片会比左边的图片有更好的推理结果,因为它所包含的空白边框更少。你可以通过 Bound 类来寻找图片的有效边界,即能把图中所有白色像素点覆盖的最小矩形。在得到 x 轴最左坐标,y 轴最上坐标 ,以及矩形高度和宽度后,就可以用这些信息截取出我们想要的图形(如右图所示)实现代码如下:
272
271
273
272
``` java
274
273
RectF bound = maxBound. getBound();
@@ -294,6 +293,4 @@ Deep Java Library (DJL) 是一个基于 Java 的深度学习框架,同时支
294
293
295
294
它同时拥有着强大的模型库支持:只需一行便可以轻松读取各种预训练的模型。现在 DJL 的模型库同时支持高达 70 个来自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。
296
295
297
- > 项目地址:https://github.com/awslabs/djl/
298
-
299
- 在最新的版本中 DJL 0.7.0 添加了对于 MXNet 1.7.0、PyTorch 1.6.0、TensorFlow 2.3.0 的支持。我们同时也添加了 ONNXRuntime 以及 PyTorch 在安卓平台的支持。
296
+ > 项目地址:https://github.com/awslabs/djl/
0 commit comments