Skip to content

Commit 477e439

Browse files
authored
Upgrade to djl 0.13.0 (#176)
1 parent 011c94a commit 477e439

File tree

88 files changed

+766
-1291
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+766
-1291
lines changed

chapter_attention-mechanisms/attention-cues.ipynb

+1-8
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@
183183
"metadata": {},
184184
"outputs": [],
185185
"source": [
186-
"NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));"
186+
"NDManager manager = NDManager.newBaseManager();"
187187
]
188188
},
189189
{
@@ -297,13 +297,6 @@
297297
"1. What can be the volitional cue when decoding a sequence token by token in machine translation? What are the nonvolitional cues and the sensory inputs?\n",
298298
"1. Randomly generate a $10 \\times 10$ matrix and use the softmax operation to ensure each row is a valid probability distribution. Visualize the output attention weights.\n"
299299
]
300-
},
301-
{
302-
"cell_type": "code",
303-
"execution_count": null,
304-
"metadata": {},
305-
"outputs": [],
306-
"source": []
307300
}
308301
],
309302
"metadata": {

chapter_attention-mechanisms/attention-scoring-functions.ipynb

+72-79
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
"metadata": {},
9393
"outputs": [],
9494
"source": [
95-
"NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));"
95+
"NDManager manager = NDManager.newBaseManager();"
9696
]
9797
},
9898
{
@@ -136,20 +136,19 @@
136136
" // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray\n",
137137
" if (validLens == null) {\n",
138138
" return X.softmax(-1);\n",
139+
" }\n",
140+
" \n",
141+
" Shape shape = X.getShape();\n",
142+
" if (validLens.getShape().dimension() == 1) {\n",
143+
" validLens = validLens.repeat(shape.get(1));\n",
139144
" } else {\n",
140-
" Shape shape = X.getShape();\n",
141-
" if (validLens.getShape().dimension() == 1) {\n",
142-
" validLens = validLens.repeat(shape.get(1));\n",
143-
" } else {\n",
144-
" validLens = validLens.reshape(-1);\n",
145-
" }\n",
146-
" // On the last axis, replace masked elements with a very large negative\n",
147-
" // value, whose exponentiation outputs 0\n",
148-
" X =\n",
149-
" X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n",
150-
" .sequenceMask(validLens, (float) -1E6);\n",
151-
" return X.softmax(-1).reshape(shape);\n",
145+
" validLens = validLens.reshape(-1);\n",
152146
" }\n",
147+
" // On the last axis, replace masked elements with a very large negative\n",
148+
" // value, whose exponentiation outputs 0\n",
149+
" X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n",
150+
" .sequenceMask(validLens, (float) -1E6);\n",
151+
" return X.softmax(-1).reshape(shape);\n",
153152
"}"
154153
]
155154
},
@@ -174,10 +173,9 @@
174173
"metadata": {},
175174
"outputs": [],
176175
"source": [
177-
"System.out.println(\n",
178-
" maskedSoftmax(\n",
179-
" manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
180-
" manager.create(new float[] {2, 3})));"
176+
"maskedSoftmax(\n",
177+
" manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
178+
" manager.create(new float[] {2, 3}));"
181179
]
182180
},
183181
{
@@ -198,10 +196,9 @@
198196
"metadata": {},
199197
"outputs": [],
200198
"source": [
201-
"System.out.println(\n",
202-
" maskedSoftmax(\n",
203-
" manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
204-
" manager.create(new float[][] {{1, 3}, {2, 4}})));"
199+
"maskedSoftmax(\n",
200+
" manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n",
201+
" manager.create(new float[][] {{1, 3}, {2, 4}}));"
205202
]
206203
},
207204
{
@@ -243,32 +240,31 @@
243240
"outputs": [],
244241
"source": [
245242
"/* Additive attention. */\n",
246-
"public class AdditiveAttention extends AbstractBlock {\n",
247-
" private static final byte VERSION = 1;\n",
243+
"public static class AdditiveAttention extends AbstractBlock {\n",
244+
"\n",
248245
" private Linear W_k;\n",
249246
" private Linear W_q;\n",
250247
" private Linear W_v;\n",
251248
" private Dropout dropout;\n",
252249
" public NDArray attentionWeights;\n",
253250
"\n",
254251
" public AdditiveAttention(int numHiddens, float dropout) {\n",
255-
" super(VERSION);\n",
256-
" this.W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
257-
" this.addChildBlock(\"W_k\", this.W_k);\n",
252+
" W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
253+
" addChildBlock(\"W_k\", W_k);\n",
258254
"\n",
259-
" this.W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
260-
" this.addChildBlock(\"W_q\", this.W_q);\n",
255+
" W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n",
256+
" addChildBlock(\"W_q\", W_q);\n",
261257
"\n",
262-
" this.W_v = Linear.builder().setUnits(1).optBias(false).build();\n",
263-
" this.addChildBlock(\"W_v\", this.W_v);\n",
258+
" W_v = Linear.builder().setUnits(1).optBias(false).build();\n",
259+
" addChildBlock(\"W_v\", W_v);\n",
264260
"\n",
265261
" this.dropout = Dropout.builder().optRate(dropout).build();\n",
266-
" this.addChildBlock(\"dropout\", this.dropout);\n",
262+
" addChildBlock(\"dropout\", this.dropout);\n",
267263
" }\n",
268264
"\n",
269265
" @Override\n",
270266
" protected NDList forwardInternal(\n",
271-
" ParameterStore parameterStore,\n",
267+
" ParameterStore ps,\n",
272268
" NDList inputs,\n",
273269
" boolean training,\n",
274270
" PairList<String, Object> params) {\n",
@@ -279,8 +275,8 @@
279275
" NDArray values = inputs.get(2);\n",
280276
" NDArray validLens = inputs.get(3);\n",
281277
"\n",
282-
" queries = this.W_q.forward(parameterStore, new NDList(queries), training, params).head();\n",
283-
" keys = this.W_k.forward(parameterStore, new NDList(keys), training, params).head();\n",
278+
" queries = W_q.forward(ps, new NDList(queries), training, params).head();\n",
279+
" keys = W_k.forward(ps, new NDList(keys), training, params).head();\n",
284280
" // After dimension expansion, shape of `queries`: (`batchSize`, no. of\n",
285281
" // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,\n",
286282
" // no. of key-value pairs, `numHiddens`). Sum them up with\n",
@@ -290,18 +286,12 @@
290286
" // There is only one output of `this.W_v`, so we remove the last\n",
291287
" // one-dimensional entry from the shape. Shape of `scores`:\n",
292288
" // (`batchSize`, no. of queries, no. of key-value pairs)\n",
293-
" NDArray result =\n",
294-
" this.W_v.forward(parameterStore, new NDList(features), training, params).head();\n",
289+
" NDArray result = W_v.forward(ps, new NDList(features), training, params).head();\n",
295290
" NDArray scores = result.squeeze(-1);\n",
296-
" this.attentionWeights = maskedSoftmax(scores, validLens);\n",
297-
" // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n",
298-
" // dimension)\n",
299-
" return new NDList(\n",
300-
" this.dropout\n",
301-
" .forward(\n",
302-
" parameterStore, new NDList(this.attentionWeights), training, params)\n",
303-
" .head()\n",
304-
" .batchDot(values));\n",
291+
" attentionWeights = maskedSoftmax(scores, validLens);\n",
292+
" // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)\n",
293+
" NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n",
294+
" return new NDList(list.head().batchDot(values));\n",
305295
" }\n",
306296
"\n",
307297
" @Override\n",
@@ -310,8 +300,21 @@
310300
" }\n",
311301
"\n",
312302
" @Override\n",
313-
" public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n",
314-
"}"
303+
" public void initializeChildBlocks(\n",
304+
" NDManager manager, DataType dataType, Shape... inputShapes) {\n",
305+
" W_q.initialize(manager, dataType, inputShapes[0]);\n",
306+
" W_k.initialize(manager, dataType, inputShapes[1]);\n",
307+
" long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();\n",
308+
" long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();\n",
309+
" long w = Math.max(q[q.length - 2], k[k.length - 2]);\n",
310+
" long h = Math.max(q[q.length - 1], k[k.length - 1]);\n",
311+
" long[] shape = new long[] {2, 1, w, h};\n",
312+
" W_v.initialize(manager, dataType, new Shape(shape));\n",
313+
" long[] dropoutShape = new long[shape.length - 1];\n",
314+
" System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);\n",
315+
" dropout.initialize(manager, dataType, new Shape(dropoutShape));\n",
316+
" }\n",
317+
"}\n"
315318
]
316319
},
317320
{
@@ -343,12 +346,10 @@
343346
"NDArray validLens = manager.create(new float[] {2, 6});\n",
344347
"\n",
345348
"AdditiveAttention attention = new AdditiveAttention(8, 0.1f);\n",
346-
"attention\n",
347-
" .forward(\n",
348-
" new ParameterStore(manager, false),\n",
349-
" new NDList(queries, keys, values, validLens),\n",
350-
" false)\n",
351-
" .head();"
349+
"NDList input = new NDList(queries, keys, values, validLens);\n",
350+
"ParameterStore ps = new ParameterStore(manager, false);\n",
351+
"attention.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
352+
"attention.forward(ps, input, false).head();"
352353
]
353354
},
354355
{
@@ -435,22 +436,20 @@
435436
"outputs": [],
436437
"source": [
437438
"/* Scaled dot product attention. */\n",
438-
"public class DotProductAttention extends AbstractBlock {\n",
439-
" private static final byte VERSION = 1;\n",
439+
"public static class DotProductAttention extends AbstractBlock {\n",
440+
"\n",
440441
" private Dropout dropout;\n",
441442
" public NDArray attentionWeights;\n",
442443
"\n",
443444
" public DotProductAttention(float dropout) {\n",
444-
" super(VERSION);\n",
445-
"\n",
446445
" this.dropout = Dropout.builder().optRate(dropout).build();\n",
447446
" this.addChildBlock(\"dropout\", this.dropout);\n",
448447
" this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);\n",
449448
" }\n",
450449
"\n",
451450
" @Override\n",
452451
" protected NDList forwardInternal(\n",
453-
" ParameterStore parameterStore,\n",
452+
" ParameterStore ps,\n",
454453
" NDList inputs,\n",
455454
" boolean training,\n",
456455
" PairList<String, Object> params) {\n",
@@ -459,7 +458,7 @@
459458
" // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n",
460459
" // dimension)\n",
461460
" // Shape of `valid_lens`: (`batchSize`,) or (`batchSize`, no. of queries)\n",
462-
" NDArray queries = inputs.head();\n",
461+
" NDArray queries = inputs.get(0);\n",
463462
" NDArray keys = inputs.get(1);\n",
464463
" NDArray values = inputs.get(2);\n",
465464
" NDArray validLens = inputs.get(3);\n",
@@ -468,12 +467,8 @@
468467
" // Swap the last two dimensions of `keys` and perform batchDot\n",
469468
" NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));\n",
470469
" attentionWeights = maskedSoftmax(scores, validLens);\n",
471-
" return new NDList(\n",
472-
" this.dropout\n",
473-
" .forward(\n",
474-
" parameterStore, new NDList(this.attentionWeights), training, params)\n",
475-
" .head()\n",
476-
" .batchDot(values));\n",
470+
" NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n",
471+
" return new NDList(list.head().batchDot(values));\n",
477472
" }\n",
478473
"\n",
479474
" @Override\n",
@@ -482,7 +477,15 @@
482477
" }\n",
483478
"\n",
484479
" @Override\n",
485-
" public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n",
480+
" public void initializeChildBlocks(\n",
481+
" NDManager manager, DataType dataType, Shape... inputShapes) {\n",
482+
" try (NDManager sub = manager.newSubManager()) {\n",
483+
" NDArray queries = sub.zeros(inputShapes[0], dataType);\n",
484+
" NDArray keys = sub.zeros(inputShapes[1], dataType);\n",
485+
" NDArray scores = queries.batchDot(keys.swapAxes(1, 2));\n",
486+
" dropout.initialize(manager, dataType, scores.getShape());\n",
487+
" }\n",
488+
" }\n",
486489
"}"
487490
]
488491
},
@@ -508,12 +511,9 @@
508511
"source": [
509512
"queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);\n",
510513
"DotProductAttention productAttention = new DotProductAttention(0.5f);\n",
511-
"productAttention\n",
512-
" .forward(\n",
513-
" new ParameterStore(manager, false),\n",
514-
" new NDList(queries, keys, values, validLens),\n",
515-
" false)\n",
516-
" .head();"
514+
"input = new NDList(queries, keys, values, validLens);\n",
515+
"productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
516+
"productAttention.forward(ps, input, false).head();"
517517
]
518518
},
519519
{
@@ -562,13 +562,6 @@
562562
"1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?\n",
563563
"1. When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?\n"
564564
]
565-
},
566-
{
567-
"cell_type": "code",
568-
"execution_count": null,
569-
"metadata": {},
570-
"outputs": [],
571-
"source": []
572565
}
573566
],
574567
"metadata": {

0 commit comments

Comments
 (0)