|
92 | 92 | "metadata": {},
|
93 | 93 | "outputs": [],
|
94 | 94 | "source": [
|
95 |
| - "NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));" |
| 95 | + "NDManager manager = NDManager.newBaseManager();" |
96 | 96 | ]
|
97 | 97 | },
|
98 | 98 | {
|
|
136 | 136 | " // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray\n",
|
137 | 137 | " if (validLens == null) {\n",
|
138 | 138 | " return X.softmax(-1);\n",
|
| 139 | + " }\n", |
| 140 | + " \n", |
| 141 | + " Shape shape = X.getShape();\n", |
| 142 | + " if (validLens.getShape().dimension() == 1) {\n", |
| 143 | + " validLens = validLens.repeat(shape.get(1));\n", |
139 | 144 | " } else {\n",
|
140 |
| - " Shape shape = X.getShape();\n", |
141 |
| - " if (validLens.getShape().dimension() == 1) {\n", |
142 |
| - " validLens = validLens.repeat(shape.get(1));\n", |
143 |
| - " } else {\n", |
144 |
| - " validLens = validLens.reshape(-1);\n", |
145 |
| - " }\n", |
146 |
| - " // On the last axis, replace masked elements with a very large negative\n", |
147 |
| - " // value, whose exponentiation outputs 0\n", |
148 |
| - " X =\n", |
149 |
| - " X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n", |
150 |
| - " .sequenceMask(validLens, (float) -1E6);\n", |
151 |
| - " return X.softmax(-1).reshape(shape);\n", |
| 145 | + " validLens = validLens.reshape(-1);\n", |
152 | 146 | " }\n",
|
| 147 | + " // On the last axis, replace masked elements with a very large negative\n", |
| 148 | + " // value, whose exponentiation outputs 0\n", |
| 149 | + " X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n", |
| 150 | + " .sequenceMask(validLens, (float) -1E6);\n", |
| 151 | + " return X.softmax(-1).reshape(shape);\n", |
153 | 152 | "}"
|
154 | 153 | ]
|
155 | 154 | },
|
|
174 | 173 | "metadata": {},
|
175 | 174 | "outputs": [],
|
176 | 175 | "source": [
|
177 |
| - "System.out.println(\n", |
178 |
| - " maskedSoftmax(\n", |
179 |
| - " manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", |
180 |
| - " manager.create(new float[] {2, 3})));" |
| 176 | + "maskedSoftmax(\n", |
| 177 | + " manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", |
| 178 | + " manager.create(new float[] {2, 3}));" |
181 | 179 | ]
|
182 | 180 | },
|
183 | 181 | {
|
|
198 | 196 | "metadata": {},
|
199 | 197 | "outputs": [],
|
200 | 198 | "source": [
|
201 |
| - "System.out.println(\n", |
202 |
| - " maskedSoftmax(\n", |
203 |
| - " manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", |
204 |
| - " manager.create(new float[][] {{1, 3}, {2, 4}})));" |
| 199 | + "maskedSoftmax(\n", |
| 200 | + " manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", |
| 201 | + " manager.create(new float[][] {{1, 3}, {2, 4}}));" |
205 | 202 | ]
|
206 | 203 | },
|
207 | 204 | {
|
|
243 | 240 | "outputs": [],
|
244 | 241 | "source": [
|
245 | 242 | "/* Additive attention. */\n",
|
246 |
| - "public class AdditiveAttention extends AbstractBlock {\n", |
247 |
| - " private static final byte VERSION = 1;\n", |
| 243 | + "public static class AdditiveAttention extends AbstractBlock {\n", |
| 244 | + "\n", |
248 | 245 | " private Linear W_k;\n",
|
249 | 246 | " private Linear W_q;\n",
|
250 | 247 | " private Linear W_v;\n",
|
251 | 248 | " private Dropout dropout;\n",
|
252 | 249 | " public NDArray attentionWeights;\n",
|
253 | 250 | "\n",
|
254 | 251 | " public AdditiveAttention(int numHiddens, float dropout) {\n",
|
255 |
| - " super(VERSION);\n", |
256 |
| - " this.W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", |
257 |
| - " this.addChildBlock(\"W_k\", this.W_k);\n", |
| 252 | + " W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", |
| 253 | + " addChildBlock(\"W_k\", W_k);\n", |
258 | 254 | "\n",
|
259 |
| - " this.W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", |
260 |
| - " this.addChildBlock(\"W_q\", this.W_q);\n", |
| 255 | + " W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", |
| 256 | + " addChildBlock(\"W_q\", W_q);\n", |
261 | 257 | "\n",
|
262 |
| - " this.W_v = Linear.builder().setUnits(1).optBias(false).build();\n", |
263 |
| - " this.addChildBlock(\"W_v\", this.W_v);\n", |
| 258 | + " W_v = Linear.builder().setUnits(1).optBias(false).build();\n", |
| 259 | + " addChildBlock(\"W_v\", W_v);\n", |
264 | 260 | "\n",
|
265 | 261 | " this.dropout = Dropout.builder().optRate(dropout).build();\n",
|
266 |
| - " this.addChildBlock(\"dropout\", this.dropout);\n", |
| 262 | + " addChildBlock(\"dropout\", this.dropout);\n", |
267 | 263 | " }\n",
|
268 | 264 | "\n",
|
269 | 265 | " @Override\n",
|
270 | 266 | " protected NDList forwardInternal(\n",
|
271 |
| - " ParameterStore parameterStore,\n", |
| 267 | + " ParameterStore ps,\n", |
272 | 268 | " NDList inputs,\n",
|
273 | 269 | " boolean training,\n",
|
274 | 270 | " PairList<String, Object> params) {\n",
|
|
279 | 275 | " NDArray values = inputs.get(2);\n",
|
280 | 276 | " NDArray validLens = inputs.get(3);\n",
|
281 | 277 | "\n",
|
282 |
| - " queries = this.W_q.forward(parameterStore, new NDList(queries), training, params).head();\n", |
283 |
| - " keys = this.W_k.forward(parameterStore, new NDList(keys), training, params).head();\n", |
| 278 | + " queries = W_q.forward(ps, new NDList(queries), training, params).head();\n", |
| 279 | + " keys = W_k.forward(ps, new NDList(keys), training, params).head();\n", |
284 | 280 | " // After dimension expansion, shape of `queries`: (`batchSize`, no. of\n",
|
285 | 281 | " // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,\n",
|
286 | 282 | " // no. of key-value pairs, `numHiddens`). Sum them up with\n",
|
|
290 | 286 | " // There is only one output of `this.W_v`, so we remove the last\n",
|
291 | 287 | " // one-dimensional entry from the shape. Shape of `scores`:\n",
|
292 | 288 | " // (`batchSize`, no. of queries, no. of key-value pairs)\n",
|
293 |
| - " NDArray result =\n", |
294 |
| - " this.W_v.forward(parameterStore, new NDList(features), training, params).head();\n", |
| 289 | + " NDArray result = W_v.forward(ps, new NDList(features), training, params).head();\n", |
295 | 290 | " NDArray scores = result.squeeze(-1);\n",
|
296 |
| - " this.attentionWeights = maskedSoftmax(scores, validLens);\n", |
297 |
| - " // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n", |
298 |
| - " // dimension)\n", |
299 |
| - " return new NDList(\n", |
300 |
| - " this.dropout\n", |
301 |
| - " .forward(\n", |
302 |
| - " parameterStore, new NDList(this.attentionWeights), training, params)\n", |
303 |
| - " .head()\n", |
304 |
| - " .batchDot(values));\n", |
| 291 | + " attentionWeights = maskedSoftmax(scores, validLens);\n", |
| 292 | + " // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)\n", |
| 293 | + " NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n", |
| 294 | + " return new NDList(list.head().batchDot(values));\n", |
305 | 295 | " }\n",
|
306 | 296 | "\n",
|
307 | 297 | " @Override\n",
|
|
310 | 300 | " }\n",
|
311 | 301 | "\n",
|
312 | 302 | " @Override\n",
|
313 |
| - " public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n", |
314 |
| - "}" |
| 303 | + " public void initializeChildBlocks(\n", |
| 304 | + " NDManager manager, DataType dataType, Shape... inputShapes) {\n", |
| 305 | + " W_q.initialize(manager, dataType, inputShapes[0]);\n", |
| 306 | + " W_k.initialize(manager, dataType, inputShapes[1]);\n", |
| 307 | + " long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();\n", |
| 308 | + " long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();\n", |
| 309 | + " long w = Math.max(q[q.length - 2], k[k.length - 2]);\n", |
| 310 | + " long h = Math.max(q[q.length - 1], k[k.length - 1]);\n", |
| 311 | + " long[] shape = new long[] {2, 1, w, h};\n", |
| 312 | + " W_v.initialize(manager, dataType, new Shape(shape));\n", |
| 313 | + " long[] dropoutShape = new long[shape.length - 1];\n", |
| 314 | + " System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);\n", |
| 315 | + " dropout.initialize(manager, dataType, new Shape(dropoutShape));\n", |
| 316 | + " }\n", |
| 317 | + "}\n" |
315 | 318 | ]
|
316 | 319 | },
|
317 | 320 | {
|
|
343 | 346 | "NDArray validLens = manager.create(new float[] {2, 6});\n",
|
344 | 347 | "\n",
|
345 | 348 | "AdditiveAttention attention = new AdditiveAttention(8, 0.1f);\n",
|
346 |
| - "attention\n", |
347 |
| - " .forward(\n", |
348 |
| - " new ParameterStore(manager, false),\n", |
349 |
| - " new NDList(queries, keys, values, validLens),\n", |
350 |
| - " false)\n", |
351 |
| - " .head();" |
| 349 | + "NDList input = new NDList(queries, keys, values, validLens);\n", |
| 350 | + "ParameterStore ps = new ParameterStore(manager, false);\n", |
| 351 | + "attention.initialize(manager, DataType.FLOAT32, input.getShapes());\n", |
| 352 | + "attention.forward(ps, input, false).head();" |
352 | 353 | ]
|
353 | 354 | },
|
354 | 355 | {
|
|
435 | 436 | "outputs": [],
|
436 | 437 | "source": [
|
437 | 438 | "/* Scaled dot product attention. */\n",
|
438 |
| - "public class DotProductAttention extends AbstractBlock {\n", |
439 |
| - " private static final byte VERSION = 1;\n", |
| 439 | + "public static class DotProductAttention extends AbstractBlock {\n", |
| 440 | + "\n", |
440 | 441 | " private Dropout dropout;\n",
|
441 | 442 | " public NDArray attentionWeights;\n",
|
442 | 443 | "\n",
|
443 | 444 | " public DotProductAttention(float dropout) {\n",
|
444 |
| - " super(VERSION);\n", |
445 |
| - "\n", |
446 | 445 | " this.dropout = Dropout.builder().optRate(dropout).build();\n",
|
447 | 446 | " this.addChildBlock(\"dropout\", this.dropout);\n",
|
448 | 447 | " this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);\n",
|
449 | 448 | " }\n",
|
450 | 449 | "\n",
|
451 | 450 | " @Override\n",
|
452 | 451 | " protected NDList forwardInternal(\n",
|
453 |
| - " ParameterStore parameterStore,\n", |
| 452 | + " ParameterStore ps,\n", |
454 | 453 | " NDList inputs,\n",
|
455 | 454 | " boolean training,\n",
|
456 | 455 | " PairList<String, Object> params) {\n",
|
|
459 | 458 | " // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n",
|
460 | 459 | " // dimension)\n",
|
461 | 460 | " // Shape of `valid_lens`: (`batchSize`,) or (`batchSize`, no. of queries)\n",
|
462 |
| - " NDArray queries = inputs.head();\n", |
| 461 | + " NDArray queries = inputs.get(0);\n", |
463 | 462 | " NDArray keys = inputs.get(1);\n",
|
464 | 463 | " NDArray values = inputs.get(2);\n",
|
465 | 464 | " NDArray validLens = inputs.get(3);\n",
|
|
468 | 467 | " // Swap the last two dimensions of `keys` and perform batchDot\n",
|
469 | 468 | " NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));\n",
|
470 | 469 | " attentionWeights = maskedSoftmax(scores, validLens);\n",
|
471 |
| - " return new NDList(\n", |
472 |
| - " this.dropout\n", |
473 |
| - " .forward(\n", |
474 |
| - " parameterStore, new NDList(this.attentionWeights), training, params)\n", |
475 |
| - " .head()\n", |
476 |
| - " .batchDot(values));\n", |
| 470 | + " NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n", |
| 471 | + " return new NDList(list.head().batchDot(values));\n", |
477 | 472 | " }\n",
|
478 | 473 | "\n",
|
479 | 474 | " @Override\n",
|
|
482 | 477 | " }\n",
|
483 | 478 | "\n",
|
484 | 479 | " @Override\n",
|
485 |
| - " public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n", |
| 480 | + " public void initializeChildBlocks(\n", |
| 481 | + " NDManager manager, DataType dataType, Shape... inputShapes) {\n", |
| 482 | + " try (NDManager sub = manager.newSubManager()) {\n", |
| 483 | + " NDArray queries = sub.zeros(inputShapes[0], dataType);\n", |
| 484 | + " NDArray keys = sub.zeros(inputShapes[1], dataType);\n", |
| 485 | + " NDArray scores = queries.batchDot(keys.swapAxes(1, 2));\n", |
| 486 | + " dropout.initialize(manager, dataType, scores.getShape());\n", |
| 487 | + " }\n", |
| 488 | + " }\n", |
486 | 489 | "}"
|
487 | 490 | ]
|
488 | 491 | },
|
|
508 | 511 | "source": [
|
509 | 512 | "queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);\n",
|
510 | 513 | "DotProductAttention productAttention = new DotProductAttention(0.5f);\n",
|
511 |
| - "productAttention\n", |
512 |
| - " .forward(\n", |
513 |
| - " new ParameterStore(manager, false),\n", |
514 |
| - " new NDList(queries, keys, values, validLens),\n", |
515 |
| - " false)\n", |
516 |
| - " .head();" |
| 514 | + "input = new NDList(queries, keys, values, validLens);\n", |
| 515 | + "productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());\n", |
| 516 | + "productAttention.forward(ps, input, false).head();" |
517 | 517 | ]
|
518 | 518 | },
|
519 | 519 | {
|
|
562 | 562 | "1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?\n",
|
563 | 563 | "1. When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?\n"
|
564 | 564 | ]
|
565 |
| - }, |
566 |
| - { |
567 |
| - "cell_type": "code", |
568 |
| - "execution_count": null, |
569 |
| - "metadata": {}, |
570 |
| - "outputs": [], |
571 |
| - "source": [] |
572 | 565 | }
|
573 | 566 | ],
|
574 | 567 | "metadata": {
|
|
0 commit comments