|  | 
| 92 | 92 |    "metadata": {}, | 
| 93 | 93 |    "outputs": [], | 
| 94 | 94 |    "source": [ | 
| 95 |  | -    "NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));" | 
|  | 95 | +    "NDManager manager = NDManager.newBaseManager();" | 
| 96 | 96 |    ] | 
| 97 | 97 |   }, | 
| 98 | 98 |   { | 
|  | 
| 136 | 136 |     "    // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray\n", | 
| 137 | 137 |     "    if (validLens == null) {\n", | 
| 138 | 138 |     "        return X.softmax(-1);\n", | 
|  | 139 | +    "    }\n", | 
|  | 140 | +    "    \n", | 
|  | 141 | +    "    Shape shape = X.getShape();\n", | 
|  | 142 | +    "    if (validLens.getShape().dimension() == 1) {\n", | 
|  | 143 | +    "        validLens = validLens.repeat(shape.get(1));\n", | 
| 139 | 144 |     "    } else {\n", | 
| 140 |  | -    "        Shape shape = X.getShape();\n", | 
| 141 |  | -    "        if (validLens.getShape().dimension() == 1) {\n", | 
| 142 |  | -    "            validLens = validLens.repeat(shape.get(1));\n", | 
| 143 |  | -    "        } else {\n", | 
| 144 |  | -    "            validLens = validLens.reshape(-1);\n", | 
| 145 |  | -    "        }\n", | 
| 146 |  | -    "        // On the last axis, replace masked elements with a very large negative\n", | 
| 147 |  | -    "        // value, whose exponentiation outputs 0\n", | 
| 148 |  | -    "        X =\n", | 
| 149 |  | -    "                X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n", | 
| 150 |  | -    "                        .sequenceMask(validLens, (float) -1E6);\n", | 
| 151 |  | -    "        return X.softmax(-1).reshape(shape);\n", | 
|  | 145 | +    "        validLens = validLens.reshape(-1);\n", | 
| 152 | 146 |     "    }\n", | 
|  | 147 | +    "    // On the last axis, replace masked elements with a very large negative\n", | 
|  | 148 | +    "    // value, whose exponentiation outputs 0\n", | 
|  | 149 | +    "    X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))\n", | 
|  | 150 | +    "            .sequenceMask(validLens, (float) -1E6);\n", | 
|  | 151 | +    "    return X.softmax(-1).reshape(shape);\n", | 
| 153 | 152 |     "}" | 
| 154 | 153 |    ] | 
| 155 | 154 |   }, | 
|  | 
| 174 | 173 |    "metadata": {}, | 
| 175 | 174 |    "outputs": [], | 
| 176 | 175 |    "source": [ | 
| 177 |  | -    "System.out.println(\n", | 
| 178 |  | -    "        maskedSoftmax(\n", | 
| 179 |  | -    "                manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", | 
| 180 |  | -    "                manager.create(new float[] {2, 3})));" | 
|  | 176 | +    "maskedSoftmax(\n", | 
|  | 177 | +    "        manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", | 
|  | 178 | +    "        manager.create(new float[] {2, 3}));" | 
| 181 | 179 |    ] | 
| 182 | 180 |   }, | 
| 183 | 181 |   { | 
|  | 
| 198 | 196 |    "metadata": {}, | 
| 199 | 197 |    "outputs": [], | 
| 200 | 198 |    "source": [ | 
| 201 |  | -    "System.out.println(\n", | 
| 202 |  | -    "        maskedSoftmax(\n", | 
| 203 |  | -    "                manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", | 
| 204 |  | -    "                manager.create(new float[][] {{1, 3}, {2, 4}})));" | 
|  | 199 | +    "maskedSoftmax(\n", | 
|  | 200 | +    "        manager.randomUniform(0, 1, new Shape(2, 2, 4)),\n", | 
|  | 201 | +    "        manager.create(new float[][] {{1, 3}, {2, 4}}));" | 
| 205 | 202 |    ] | 
| 206 | 203 |   }, | 
| 207 | 204 |   { | 
|  | 
| 243 | 240 |    "outputs": [], | 
| 244 | 241 |    "source": [ | 
| 245 | 242 |     "/* Additive attention. */\n", | 
| 246 |  | -    "public class AdditiveAttention extends AbstractBlock {\n", | 
| 247 |  | -    "    private static final byte VERSION = 1;\n", | 
|  | 243 | +    "public static class AdditiveAttention extends AbstractBlock {\n", | 
|  | 244 | +    "\n", | 
| 248 | 245 |     "    private Linear W_k;\n", | 
| 249 | 246 |     "    private Linear W_q;\n", | 
| 250 | 247 |     "    private Linear W_v;\n", | 
| 251 | 248 |     "    private Dropout dropout;\n", | 
| 252 | 249 |     "    public NDArray attentionWeights;\n", | 
| 253 | 250 |     "\n", | 
| 254 | 251 |     "    public AdditiveAttention(int numHiddens, float dropout) {\n", | 
| 255 |  | -    "        super(VERSION);\n", | 
| 256 |  | -    "        this.W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", | 
| 257 |  | -    "        this.addChildBlock(\"W_k\", this.W_k);\n", | 
|  | 252 | +    "        W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", | 
|  | 253 | +    "        addChildBlock(\"W_k\", W_k);\n", | 
| 258 | 254 |     "\n", | 
| 259 |  | -    "        this.W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", | 
| 260 |  | -    "        this.addChildBlock(\"W_q\", this.W_q);\n", | 
|  | 255 | +    "        W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();\n", | 
|  | 256 | +    "        addChildBlock(\"W_q\", W_q);\n", | 
| 261 | 257 |     "\n", | 
| 262 |  | -    "        this.W_v = Linear.builder().setUnits(1).optBias(false).build();\n", | 
| 263 |  | -    "        this.addChildBlock(\"W_v\", this.W_v);\n", | 
|  | 258 | +    "        W_v = Linear.builder().setUnits(1).optBias(false).build();\n", | 
|  | 259 | +    "        addChildBlock(\"W_v\", W_v);\n", | 
| 264 | 260 |     "\n", | 
| 265 | 261 |     "        this.dropout = Dropout.builder().optRate(dropout).build();\n", | 
| 266 |  | -    "        this.addChildBlock(\"dropout\", this.dropout);\n", | 
|  | 262 | +    "        addChildBlock(\"dropout\", this.dropout);\n", | 
| 267 | 263 |     "    }\n", | 
| 268 | 264 |     "\n", | 
| 269 | 265 |     "    @Override\n", | 
| 270 | 266 |     "    protected NDList forwardInternal(\n", | 
| 271 |  | -    "            ParameterStore parameterStore,\n", | 
|  | 267 | +    "            ParameterStore ps,\n", | 
| 272 | 268 |     "            NDList inputs,\n", | 
| 273 | 269 |     "            boolean training,\n", | 
| 274 | 270 |     "            PairList<String, Object> params) {\n", | 
|  | 
| 279 | 275 |     "        NDArray values = inputs.get(2);\n", | 
| 280 | 276 |     "        NDArray validLens = inputs.get(3);\n", | 
| 281 | 277 |     "\n", | 
| 282 |  | -    "        queries = this.W_q.forward(parameterStore, new NDList(queries), training, params).head();\n", | 
| 283 |  | -    "        keys = this.W_k.forward(parameterStore, new NDList(keys), training, params).head();\n", | 
|  | 278 | +    "        queries = W_q.forward(ps, new NDList(queries), training, params).head();\n", | 
|  | 279 | +    "        keys = W_k.forward(ps, new NDList(keys), training, params).head();\n", | 
| 284 | 280 |     "        // After dimension expansion, shape of `queries`: (`batchSize`, no. of\n", | 
| 285 | 281 |     "        // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,\n", | 
| 286 | 282 |     "        // no. of key-value pairs, `numHiddens`). Sum them up with\n", | 
|  | 
| 290 | 286 |     "        // There is only one output of `this.W_v`, so we remove the last\n", | 
| 291 | 287 |     "        // one-dimensional entry from the shape. Shape of `scores`:\n", | 
| 292 | 288 |     "        // (`batchSize`, no. of queries, no. of key-value pairs)\n", | 
| 293 |  | -    "        NDArray result =\n", | 
| 294 |  | -    "                this.W_v.forward(parameterStore, new NDList(features), training, params).head();\n", | 
|  | 289 | +    "        NDArray result = W_v.forward(ps, new NDList(features), training, params).head();\n", | 
| 295 | 290 |     "        NDArray scores = result.squeeze(-1);\n", | 
| 296 |  | -    "        this.attentionWeights = maskedSoftmax(scores, validLens);\n", | 
| 297 |  | -    "        // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n", | 
| 298 |  | -    "        // dimension)\n", | 
| 299 |  | -    "        return new NDList(\n", | 
| 300 |  | -    "                this.dropout\n", | 
| 301 |  | -    "                        .forward(\n", | 
| 302 |  | -    "                                parameterStore, new NDList(this.attentionWeights), training, params)\n", | 
| 303 |  | -    "                        .head()\n", | 
| 304 |  | -    "                        .batchDot(values));\n", | 
|  | 291 | +    "        attentionWeights = maskedSoftmax(scores, validLens);\n", | 
|  | 292 | +    "        // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)\n", | 
|  | 293 | +    "        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n", | 
|  | 294 | +    "        return new NDList(list.head().batchDot(values));\n", | 
| 305 | 295 |     "    }\n", | 
| 306 | 296 |     "\n", | 
| 307 | 297 |     "    @Override\n", | 
|  | 
| 310 | 300 |     "    }\n", | 
| 311 | 301 |     "\n", | 
| 312 | 302 |     "    @Override\n", | 
| 313 |  | -    "    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n", | 
| 314 |  | -    "}" | 
|  | 303 | +    "    public void initializeChildBlocks(\n", | 
|  | 304 | +    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n", | 
|  | 305 | +    "        W_q.initialize(manager, dataType, inputShapes[0]);\n", | 
|  | 306 | +    "        W_k.initialize(manager, dataType, inputShapes[1]);\n", | 
|  | 307 | +    "        long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();\n", | 
|  | 308 | +    "        long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();\n", | 
|  | 309 | +    "        long w = Math.max(q[q.length - 2], k[k.length - 2]);\n", | 
|  | 310 | +    "        long h = Math.max(q[q.length - 1], k[k.length - 1]);\n", | 
|  | 311 | +    "        long[] shape = new long[] {2, 1, w, h};\n", | 
|  | 312 | +    "        W_v.initialize(manager, dataType, new Shape(shape));\n", | 
|  | 313 | +    "        long[] dropoutShape = new long[shape.length - 1];\n", | 
|  | 314 | +    "        System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);\n", | 
|  | 315 | +    "        dropout.initialize(manager, dataType, new Shape(dropoutShape));\n", | 
|  | 316 | +    "    }\n", | 
|  | 317 | +    "}\n" | 
| 315 | 318 |    ] | 
| 316 | 319 |   }, | 
| 317 | 320 |   { | 
|  | 
| 343 | 346 |     "NDArray validLens = manager.create(new float[] {2, 6});\n", | 
| 344 | 347 |     "\n", | 
| 345 | 348 |     "AdditiveAttention attention = new AdditiveAttention(8, 0.1f);\n", | 
| 346 |  | -    "attention\n", | 
| 347 |  | -    "        .forward(\n", | 
| 348 |  | -    "                new ParameterStore(manager, false),\n", | 
| 349 |  | -    "                new NDList(queries, keys, values, validLens),\n", | 
| 350 |  | -    "                false)\n", | 
| 351 |  | -    "        .head();" | 
|  | 349 | +    "NDList input = new NDList(queries, keys, values, validLens);\n", | 
|  | 350 | +    "ParameterStore ps = new ParameterStore(manager, false);\n", | 
|  | 351 | +    "attention.initialize(manager, DataType.FLOAT32, input.getShapes());\n", | 
|  | 352 | +    "attention.forward(ps, input, false).head();" | 
| 352 | 353 |    ] | 
| 353 | 354 |   }, | 
| 354 | 355 |   { | 
|  | 
| 435 | 436 |    "outputs": [], | 
| 436 | 437 |    "source": [ | 
| 437 | 438 |     "/* Scaled dot product attention. */\n", | 
| 438 |  | -    "public class DotProductAttention extends AbstractBlock {\n", | 
| 439 |  | -    "    private static final byte VERSION = 1;\n", | 
|  | 439 | +    "public static class DotProductAttention extends AbstractBlock {\n", | 
|  | 440 | +    "\n", | 
| 440 | 441 |     "    private Dropout dropout;\n", | 
| 441 | 442 |     "    public NDArray attentionWeights;\n", | 
| 442 | 443 |     "\n", | 
| 443 | 444 |     "    public DotProductAttention(float dropout) {\n", | 
| 444 |  | -    "        super(VERSION);\n", | 
| 445 |  | -    "\n", | 
| 446 | 445 |     "        this.dropout = Dropout.builder().optRate(dropout).build();\n", | 
| 447 | 446 |     "        this.addChildBlock(\"dropout\", this.dropout);\n", | 
| 448 | 447 |     "        this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);\n", | 
| 449 | 448 |     "    }\n", | 
| 450 | 449 |     "\n", | 
| 451 | 450 |     "    @Override\n", | 
| 452 | 451 |     "    protected NDList forwardInternal(\n", | 
| 453 |  | -    "            ParameterStore parameterStore,\n", | 
|  | 452 | +    "            ParameterStore ps,\n", | 
| 454 | 453 |     "            NDList inputs,\n", | 
| 455 | 454 |     "            boolean training,\n", | 
| 456 | 455 |     "            PairList<String, Object> params) {\n", | 
|  | 
| 459 | 458 |     "        // Shape of `values`: (`batchSize`, no. of key-value pairs, value\n", | 
| 460 | 459 |     "        // dimension)\n", | 
| 461 | 460 |     "        // Shape of `validLens`: (`batchSize`,) or (`batchSize`, no. of queries)\n", | 
| 462 |  | -    "        NDArray queries = inputs.head();\n", | 
|  | 461 | +    "        NDArray queries = inputs.get(0);\n", | 
| 463 | 462 |     "        NDArray keys = inputs.get(1);\n", | 
| 464 | 463 |     "        NDArray values = inputs.get(2);\n", | 
| 465 | 464 |     "        NDArray validLens = inputs.get(3);\n", | 
|  | 
| 468 | 467 |     "        // Swap the last two dimensions of `keys` and perform batchDot\n", | 
| 469 | 468 |     "        NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));\n", | 
| 470 | 469 |     "        attentionWeights = maskedSoftmax(scores, validLens);\n", | 
| 471 |  | -    "        return new NDList(\n", | 
| 472 |  | -    "                this.dropout\n", | 
| 473 |  | -    "                        .forward(\n", | 
| 474 |  | -    "                                parameterStore, new NDList(this.attentionWeights), training, params)\n", | 
| 475 |  | -    "                        .head()\n", | 
| 476 |  | -    "                        .batchDot(values));\n", | 
|  | 470 | +    "        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);\n", | 
|  | 471 | +    "        return new NDList(list.head().batchDot(values));\n", | 
| 477 | 472 |     "    }\n", | 
| 478 | 473 |     "\n", | 
| 479 | 474 |     "    @Override\n", | 
|  | 
| 482 | 477 |     "    }\n", | 
| 483 | 478 |     "\n", | 
| 484 | 479 |     "    @Override\n", | 
| 485 |  | -    "    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {}\n", | 
|  | 480 | +    "    public void initializeChildBlocks(\n", | 
|  | 481 | +    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n", | 
|  | 482 | +    "        try (NDManager sub = manager.newSubManager()) {\n", | 
|  | 483 | +    "            NDArray queries = sub.zeros(inputShapes[0], dataType);\n", | 
|  | 484 | +    "            NDArray keys = sub.zeros(inputShapes[1], dataType);\n", | 
|  | 485 | +    "            NDArray scores = queries.batchDot(keys.swapAxes(1, 2));\n", | 
|  | 486 | +    "            dropout.initialize(manager, dataType, scores.getShape());\n", | 
|  | 487 | +    "        }\n", | 
|  | 488 | +    "    }\n", | 
| 486 | 489 |     "}" | 
| 487 | 490 |    ] | 
| 488 | 491 |   }, | 
|  | 
| 508 | 511 |    "source": [ | 
| 509 | 512 |     "queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);\n", | 
| 510 | 513 |     "DotProductAttention productAttention = new DotProductAttention(0.5f);\n", | 
| 511 |  | -    "productAttention\n", | 
| 512 |  | -    "        .forward(\n", | 
| 513 |  | -    "                new ParameterStore(manager, false),\n", | 
| 514 |  | -    "                new NDList(queries, keys, values, validLens),\n", | 
| 515 |  | -    "                false)\n", | 
| 516 |  | -    "        .head();" | 
|  | 514 | +    "input = new NDList(queries, keys, values, validLens);\n", | 
|  | 515 | +    "productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());\n", | 
|  | 516 | +    "productAttention.forward(ps, input, false).head();" | 
| 517 | 517 |    ] | 
| 518 | 518 |   }, | 
| 519 | 519 |   { | 
|  | 
| 562 | 562 |     "1. Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?\n", | 
| 563 | 563 |     "1. When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?\n" | 
| 564 | 564 |    ] | 
| 565 |  | -  }, | 
| 566 |  | -  { | 
| 567 |  | -   "cell_type": "code", | 
| 568 |  | -   "execution_count": null, | 
| 569 |  | -   "metadata": {}, | 
| 570 |  | -   "outputs": [], | 
| 571 |  | -   "source": [] | 
| 572 | 565 |   } | 
| 573 | 566 |  ], | 
| 574 | 567 |  "metadata": { | 
|  | 
0 commit comments