|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Seq2Seq 기계 번역\n", |
| 8 | + "\n", |
| 9 | + "이번 프로젝트에선 임의로 Seq2Seq 모델을 아주 간단화 시켰습니다.\n", |
| 10 | + "한 언어로 된 문장을 다른 언어로 된 문장으로 번역하는 덩치가 큰 모델이 아닌\n", |
| 11 | + "영어 알파벳 문자열(\"hello\")을 스페인어 알파벳 문자열(\"hola\")로 번역하는 Mini Seq2Seq 모델을 같이 구현해 보겠습니다." |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "code", |
| 16 | + "execution_count": 115, |
| 17 | + "metadata": {}, |
| 18 | + "outputs": [], |
| 19 | + "source": [ |
| 20 | + "import torch\n", |
| 21 | + "import torch.nn as nn\n", |
| 22 | + "import torch.nn.functional as F\n", |
| 23 | + "import random\n", |
| 24 | + "import matplotlib.pyplot as plt" |
| 25 | + ] |
| 26 | + }, |
| 27 | + { |
| 28 | + "cell_type": "code", |
| 29 | + "execution_count": 116, |
| 30 | + "metadata": {}, |
| 31 | + "outputs": [ |
| 32 | + { |
| 33 | + "name": "stdout", |
| 34 | + "output_type": "stream", |
| 35 | + "text": [ |
| 36 | + "hello -> [104, 101, 108, 108, 111]\n", |
| 37 | + "hola -> [104, 111, 108, 97]\n" |
| 38 | + ] |
| 39 | + } |
| 40 | + ], |
| 41 | + "source": [ |
| 42 | + "vocab_size = 256 # 총 아스키 코드 개수\n", |
| 43 | + "x_ = list(map(ord, \"hello\")) # 아스키 코드 리스트로 변환\n", |
| 44 | + "y_ = list(map(ord, \"hola\")) # 아스키 코드 리스트로 변환\n", |
| 45 | + "print(\"hello -> \", x_)\n", |
| 46 | + "print(\"hola -> \", y_)" |
| 47 | + ] |
| 48 | + }, |
| 49 | + { |
| 50 | + "cell_type": "code", |
| 51 | + "execution_count": 117, |
| 52 | + "metadata": {}, |
| 53 | + "outputs": [], |
| 54 | + "source": [ |
| 55 | + "x = torch.LongTensor(x_)\n", |
| 56 | + "y = torch.LongTensor(y_)" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "code", |
| 61 | + "execution_count": 118, |
| 62 | + "metadata": {}, |
| 63 | + "outputs": [], |
| 64 | + "source": [ |
| 65 | + "class Seq2Seq(nn.Module):\n", |
| 66 | + " def __init__(self, vocab_size, hidden_size):\n", |
| 67 | + " super(Seq2Seq, self).__init__()\n", |
| 68 | + " self.n_layers = 1\n", |
| 69 | + " self.hidden_size = hidden_size\n", |
| 70 | + " self.embedding = nn.Embedding(vocab_size, hidden_size)\n", |
| 71 | + " self.encoder = nn.GRU(hidden_size, hidden_size)\n", |
| 72 | + " self.decoder = nn.GRU(hidden_size, hidden_size)\n", |
| 73 | + " self.project = nn.Linear(hidden_size, vocab_size)\n", |
| 74 | + "\n", |
| 75 | + " def forward(self, inputs, targets):\n", |
| 76 | + " # 인코더에 들어갈 입력\n", |
| 77 | + " initial_state = self._init_state()\n", |
| 78 | + " embedding = self.embedding(inputs).unsqueeze(1)\n", |
| 79 | + " # embedding = [seq_len, batch_size, embedding_size]\n", |
| 80 | + " \n", |
| 81 | + " # 인코더 (Encoder)\n", |
| 82 | + " encoder_output, encoder_state = self.encoder(embedding, initial_state)\n", |
| 83 | + " # encoder_output = [seq_len, batch_size, hidden_size]\n", |
| 84 | + " # encoder_state = [n_layers, seq_len, hidden_size]\n", |
| 85 | + "\n", |
| 86 | + " # 디코더에 들어갈 입력\n", |
| 87 | + " decoder_state = encoder_state\n", |
| 88 | + " decoder_input = torch.LongTensor([0])\n", |
| 89 | + " \n", |
| 90 | + " # 디코더 (Decoder)\n", |
| 91 | + " outputs = []\n", |
| 92 | + " \n", |
| 93 | + " for i in range(targets.size()[0]):\n", |
| 94 | + " decoder_input = self.embedding(decoder_input).unsqueeze(1)\n", |
| 95 | + " decoder_output, decoder_state = self.decoder(decoder_input, decoder_state)\n", |
| 96 | + " projection = self.project(decoder_output)\n", |
| 97 | + " outputs.append(projection)\n", |
| 98 | + " \n", |
| 99 | + " #티처 포싱(Teacher Forcing) 사용\n", |
| 100 | + " decoder_input = torch.LongTensor([targets[i]])\n", |
| 101 | + "\n", |
| 102 | + " outputs = torch.stack(outputs).squeeze()\n", |
| 103 | + " return outputs\n", |
| 104 | + " \n", |
| 105 | + " def _init_state(self, batch_size=1):\n", |
| 106 | + " weight = next(self.parameters()).data\n", |
| 107 | + " return weight.new(self.n_layers, batch_size, self.hidden_size).zero_()" |
| 108 | + ] |
| 109 | + }, |
| 110 | + { |
| 111 | + "cell_type": "code", |
| 112 | + "execution_count": 119, |
| 113 | + "metadata": {}, |
| 114 | + "outputs": [ |
| 115 | + { |
| 116 | + "name": "stdout", |
| 117 | + "output_type": "stream", |
| 118 | + "text": [ |
| 119 | + "Seq2Seq(\n", |
| 120 | + " (embedding): Embedding(256, 16)\n", |
| 121 | + " (encoder): GRU(16, 16)\n", |
| 122 | + " (decoder): GRU(16, 16)\n", |
| 123 | + " (project): Linear(in_features=16, out_features=256, bias=True)\n", |
| 124 | + ")\n" |
| 125 | + ] |
| 126 | + } |
| 127 | + ], |
| 128 | + "source": [ |
| 129 | + "seq2seq = Seq2Seq(vocab_size, 16)\n", |
| 130 | + "print(seq2seq)" |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "code", |
| 135 | + "execution_count": 120, |
| 136 | + "metadata": {}, |
| 137 | + "outputs": [], |
| 138 | + "source": [ |
| 139 | + "criterion = nn.CrossEntropyLoss()\n", |
| 140 | + "optimizer = torch.optim.Adam(seq2seq.parameters(), lr=1e-3)" |
| 141 | + ] |
| 142 | + }, |
| 143 | + { |
| 144 | + "cell_type": "code", |
| 145 | + "execution_count": 121, |
| 146 | + "metadata": {}, |
| 147 | + "outputs": [ |
| 148 | + { |
| 149 | + "name": "stdout", |
| 150 | + "output_type": "stream", |
| 151 | + "text": [ |
| 152 | + "\n", |
| 153 | + " 반복:0 오차: 5.596976280212402\n", |
| 154 | + "['9', 'L', '\\x98', 'L']\n", |
| 155 | + "\n", |
| 156 | + " 반복:100 오차: 2.069061756134033\n", |
| 157 | + "['h', 'o', 'l', 'a']\n", |
| 158 | + "\n", |
| 159 | + " 반복:200 오차: 0.4633035957813263\n", |
| 160 | + "['h', 'o', 'l', 'a']\n", |
| 161 | + "\n", |
| 162 | + " 반복:300 오차: 0.19558477401733398\n", |
| 163 | + "['h', 'o', 'l', 'a']\n", |
| 164 | + "\n", |
| 165 | + " 반복:400 오차: 0.11498594284057617\n", |
| 166 | + "['h', 'o', 'l', 'a']\n", |
| 167 | + "\n", |
| 168 | + " 반복:500 오차: 0.07863441854715347\n", |
| 169 | + "['h', 'o', 'l', 'a']\n", |
| 170 | + "\n", |
| 171 | + " 반복:600 오차: 0.058321841061115265\n", |
| 172 | + "['h', 'o', 'l', 'a']\n", |
| 173 | + "\n", |
| 174 | + " 반복:700 오차: 0.04549476131796837\n", |
| 175 | + "['h', 'o', 'l', 'a']\n", |
| 176 | + "\n", |
| 177 | + " 반복:800 오차: 0.03673341125249863\n", |
| 178 | + "['h', 'o', 'l', 'a']\n", |
| 179 | + "\n", |
| 180 | + " 반복:900 오차: 0.030412672087550163\n", |
| 181 | + "['h', 'o', 'l', 'a']\n" |
| 182 | + ] |
| 183 | + } |
| 184 | + ], |
| 185 | + "source": [ |
| 186 | + "log = []\n", |
| 187 | + "for i in range(1000):\n", |
| 188 | + " prediction = seq2seq(x, y)\n", |
| 189 | + " loss = criterion(prediction, y)\n", |
| 190 | + " optimizer.zero_grad()\n", |
| 191 | + " loss.backward()\n", |
| 192 | + " optimizer.step()\n", |
| 193 | + " loss_val = loss.data\n", |
| 194 | + " log.append(loss_val)\n", |
| 195 | + " if i % 100 == 0:\n", |
| 196 | + " print(\"\\n 반복:%d 오차: %s\" % (i, loss_val.item()))\n", |
| 197 | + " _, top1 = prediction.data.topk(1, 1)\n", |
| 198 | + " print([chr(c) for c in top1.squeeze().numpy().tolist()])" |
| 199 | + ] |
| 200 | + }, |
| 201 | + { |
| 202 | + "cell_type": "code", |
| 203 | + "execution_count": 122, |
| 204 | + "metadata": {}, |
| 205 | + "outputs": [ |
| 206 | + { |
| 207 | + "data": { |
| 208 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAfw0lEQVR4nO3deZgddb3n8ff3nNNLOr0m3Ul6yUaA\nINljswkqcBUBQe64M8B1lHl4vI/7dcZHRu/Dc8dnvG5XL47KgOJ1x/EC3sEoAsomiECHJWQnCQSy\ndidk6U7S6/nOH1WdNDFLZamuc+p8Xs9znlPbqfpWCj6n+le/U2XujoiIpE8m6QJERCQeCngRkZRS\nwIuIpJQCXkQkpRTwIiIplUu6gJEaGxt92rRpSZchIlI0Fi9evM3dmw41r6ACftq0aXR0dCRdhohI\n0TCz9YebpyYaEZGUUsCLiKSUAl5EJKUU8CIiKaWAFxFJKQW8iEhKKeBFRFKq6AO+d2CI7z+6jsfX\nbEu6FBGRglL0AV+WzXDro+v42V8O29dfRKQkFX3AZzPGO+dM4sGVnXT3DiRdjohIwSj6gAe4cl4L\nfYN5/rBia9KliIgUjFQE/MIpDbTUVbLo+c1JlyIiUjBSEfCZjPHOuc08+mIXu/aqmUZEBFIS8BA0\n0wwMOfct25J0KSIiBSE1AT+ntY4p46r4zZJNSZciIlIQUhPwZsaV85p5fM02tvX0JV2OiEjiUhPw\nEDTT5B3uXapmGhGRVAX8zIk1nDqhmt88r2YaEZFUBbyZceXcFp5++TW27OpNuhwRkUSlKuABrpjX\njDv89gX1iReR0pa6gJ/RVM2ZzbVqphGRkpe6gIfgYutzr+7k1df2Jl2KiEhiUhnwV8xtBmDREjXT\niEjpSmXATx5XxfzJ9SzSj55EpISlMuAhaKZZtmk3a7t6ki5FRCQRqQ34d85pxgzdYVJESlasAW9m\nL5vZC2b2nJl1xLmtg02qq+SsaeP4zZJNuPtoblpEpCCMxhn8Re4+393bR2Fbr3PlvBbWdPawamv3\naG9aRCRxqW2iAbhs9iQyaqYRkRIVd8A7cL+ZLTazGw61gJndYGYdZtbR1dV1UjfeWF3B+ac2qplG\nREpS3AF/gbsvBC4DPmZmbzl4AXe/zd3b3b29qanppBdwxdxm1m/fywsbd530dYuIFLJYA97dN4bv\nncCvgbPj3N6hXDqrmVzGdAthESk5sQW8mY01s5rhYeASYGlc2zucuqoy2qc18NDKztHetIhIouI8\ng58IPGZmzwNPAb9199/HuL3DumjmBFZu6Wbzrn1JbF5EJBGxBby7r3P3eeFrlrv/r7i2dTQXnTEB\ngIdXndyLuCIihSzV3SSHnTahmtb6MWqmEZGSUhIBb2ZcdEYTj6/ZRt/gUNLliIiMipIIeAja4ff0\nD9Hx8o6kSxERGRUlE/DnnjKeXMZ4fM22pEsRERkVJRPwYytyzJ9cz5/Xbk+6FBGRUVEyAQ/wphnj\nWbJhJ7t7B5IuRUQkdiUV8OfNaCTv8NS615IuRUQkdiUV8Aum1FORy6iZRkRKQkkFfGVZlvZpDfx5\nrS60ikj6lVTAA7xpRiMrt3Szvacv6VJERGJVcgF/9vRxACxer/7wIpJuJRfwc1rrKM9mFPAiknol\nF/CVZVlmt9bSoYAXkZQruYAHaJ82jhc27KJ3QPelEZH0KsmAf+PUBvqH8izVY/xEJMVKNuABNdOI\nSKqVZMA3VlcwvXGs7iwpIqlWkgEPwVn8M6/swN2TLkVEJBYlG/DtUxt4bU8/67btSboUEZFYlGzA\nLwzb4Z97ZWfClYiIxKNkA35GUzVV5VmWbFDAi0g6lWzAZzPGnNY6nt+grpIikk4lG/AA8ybXs3zT\nbvoH80mXIiJy0pV0wM9tq6N/KM+qLd1JlyIictKVdMDPa6sH4Hm1w4tICpV0wLc1jKGhqkwXWkUk\nlUo64M2MuW31LNGFVhFJodgD3syyZvasmS2Ke1vHY97kelZv7WZv/2DSpYiInFSjcQb/KWDFKGzn\nuMxrqyPvsGzT7qRLERE5qY4a8Gb2KTOrtcDtZvaMmV0SZeVm1ga8E/jBiRYal7nDF1pfVTu8iKRL\nlDP4j7j7buASoAG4DvhKxPX/K/A54LAdzc3sBjPrMLOOrq6uiKs9eZpqKmipq9QPnkQkdaIEvIXv\nlwM/dfdlI6Yd/kNmVwCd7r74SMu5+23u3u7u7U1NTRHKOflmt9axbJMCXkTSJUrALzaz+wkC/j4z\nq+EIZ+QjnA+8y8xeBn4JXGxmPzvuSmM0u7WOl7btoadPF1pFJD2iBPz1wOeBs9x9L1AGfPhoH3L3\nG929zd2nAR8EHnT3a0+k2LjMbq3FHVZs1oVWEUmPKAF/HrDK3Xea2bXAF4FUtWfMaqkDYJme0Soi\nKRIl4G8B9prZPOCzwFrgJ8eyEXd/2N2vOI76RsWEmgoaqytYqq6SIpIiUQJ+0IPn2l0FfMfdvwvU\nxFvW6DIzZrfWslRn8CKSIlECvtvMbiToHvlbM8sQtMOnyqyWWtZ09tA7MJR0KSIiJ0WUgP8A0EfQ\nH34L0AZ8PdaqEjC7pY7BvLN6q24dLCLpcNSAD0P950Bd2Le9192PqQ2+GMxuDS60Lt2odngRSYco\ntyp4P/AU8D7g/cCTZvbeuAsbbW0NY6itzOkHTyKSGrkIy3yBoA98J4CZNQF/AO6Ms7DRZmbMaqlT\nTxoRSY0obfCZ4XAPbY/4uaIzq6WWlZt3MzikZ7SKSPGLcgb/ezO7D7gjHP8A8Lv4SkrO7NY6+gbz\nrO3aw8xJqeoJKiIl6KgB7+7/3czeQ3BvGYDb3P3X8ZaVjNmttQAs3bhLAS8iRS/KGTzufhdwV8y1\nJG56YzVjyrIs3bSL97yxLelyREROyGED3sy6AT/ULMDdvTa2qhKSzRhvaK5hmbpKikgKHDbg3b0k\n2yhmt9Zx9zMbyeedTOaot70XESlYqewNcyJmtdTS0zfI+tf2Jl2KiMgJUcAfZPjWwbrxmIgUOwX8\nQU6fWENZ1limHzyJSJGLcquCT5hZw2gUUwjKcxlOn1ijWxaISNGLcgY/EXjazH5lZpeaWeqvPM5u\nqWPpxl0Et8EXESlOUe4m+UXgNOB24L8AL5rZl81sRsy1JWZ2ay079g6weVdv0qWIiBy3SG3w4ROd\ntoSvQaABuNPMvhZjbYk5UxdaRSQForTBf8rMFgNfAx4H5rj73wNvBN4Tc32JeENzDRlDd5YUkaIW\n5VYF44B3u/v6kRPdPR8+ACR1qspznNJUzXJdaBWRIhblZmM3mdlCM7uK4NYFj7v7M+G8FXEXmJTZ\nLbU8sW570mWIiBy3KE00/wj8GBgPNAL/ZmZfjLuwpM1pq2fr7j46d+tCq4gUpygXWa8leKLTTe5+\nE3AucF28ZSVvbltwoXXJBjXTiEhxihLwm4DKEeMVwMZ4yikcs1pqyRgsUU8aESlSUS6y7gKWmdkD\nBG3wbweeMrNvA7j7J2OsLzFV5TlOm1DDCxt2Jl2KiMhxiRLwvw5fwx6OsmIzqwQeJTjjzwF3hk08\nRWNOWx0Pr+rE3SmBH/CKSMpE6UXzYzMrB04PJ61y94EI6+4DLnb3HjMrAx4zs3vd/S8nUO+omttW\nx52LN7B5Vy8t9WOSLkdE5JhE6UVzIfAi8F3ge8BqM3vL0T7ngZ5wtCx8FdXNXea21QOwRM00IlKE\nolxk/RfgEnd/q7u/BXgH8K0oKzezrJk9B3QCD7j7k4dY5gYz6zCzjq6urmOpPXZnTKohlzH1pBGR\nohQl4MvcfdXwiLuvJjgbPyp3H3L3+UAbcLaZzT7EMre5e7u7tzc1NUWte1RUlmWZOamGF9STRkSK\nUJSA7zCzH5jZheHr+0DHsWzE3XcCDwGXHk+RSZrbVseSDbp1sIgUnygB//fAcuCT4Wt5OO2IzKzJ\nzOrD4TEE3StXHn+pyZjbVs+ufQO8+tq+pEsRETkmR+xFY2ZZ4Ifufg3wzWNcdzPw43AdGeBX7r7o\n+MpMzpzW4Betz2/YyZTxVQlXIyIS3RED3t2HzGyqmZW7e/+xrNjdlwALTqi6AnD6xBrKcxle2LiL\nK+e1JF2OiEhkUX7otA543MzuAfYMT3T3Yz2jL0rluQxvaK5VV0kRKTpR2uDXAovCZWvCV3WcRRWa\neW11vLBhF0N5XWgVkeIR5Qx+ubv/+8gJZva+mOopSAum1POTJ9bzYmc3Z0yqTbocEZFIopzB3xhx\nWmotmNwAwDPr1UwjIsXjsGfwZnYZcDnQOnznyFAtwYO3S8bU8VWMG1vOM6/s4D+fMyXpckREIjlS\nE80mgh80vQtYPGJ6N/CZOIsqNGbGwin1PPPKjqRLERGJ7LAB7+7PA8+b2S8i3j0y1RZMaeAPKzrZ\nubef+qrypMsRETmqKG3wZ5vZA2a22szWmdlLZrYu9soKzIIpwZ0ln31V7fAiUhyi9KK5naBJZjEw\nFG85hWteWz0Zg2fX7+CimROSLkdE5KgiPbLP3e+NvZICN7YixxmTannmFZ3Bi0hxiBLwD5nZ14G7\nCZ7SBIC7PxNbVQVq4dR6/uPZTQzlnWxGj/ATkcIWJeDPCd/bR0xz4OKTX05hWzC5gZ/95RXWdPYw\nc1JN0uWIiBxRlGeyXjQahRSDhVPDHzy9skMBLyIFL8ozWSea2e1mdm84fqaZXR9/aYVnWviDp8Xr\n1R9eRApflG6SPwLuA4bvlbsa+HRcBRWy4AdPDXS8/FrSpYiIHFWUgG90918BeQB3H6SEu0ueM30c\nL2/fS+fu3qRLERE5oigBv8fMxhNcWMXMzgVK9inUZ08fB8BTOosXkQIXJeD/AbgHmGFmjwM/AT4R\na1UFbFZLLVXlWZ56SQEvIoUtSi+aZ8zsrcBMwIBVpXxvmlw2wxunNijgRaTgRTmDx90H3X2Zuy8t\n5XAfdva0cazc0s3Ovcf0mFoRkVEVKeDl9Ybb4Z9+Wd0lRaRwKeCPw7zJ9ZRnMzz10vakSxEROawo\nP3Q638zGhsPXmtk3zWxq/KUVrsqyLPMn16sdXkQKWpQz+FuAvWY2D/gssJagJ01JO2t6A0s37aan\nr6SeXigiRSRKwA+6uwNXAd9x9+8CJX8jlvNOaWQo7zyts3gRKVBRAr7bzG4ErgV+a2YZoCzesgpf\n+7QGynMZHluzLelSREQOKUrAf4DgPvDXu/sWoA34eqxVFYHKsixnTWvgcQW8iBSoSGfwwM3u/icz\nOx2YD9xxtA+Z2WQze8jMlpvZMjP71IkWW2jOP7WRlVu66ezWfWlEpPBECfhHgQozawXuB64juMPk\n0QwCn3X3M4FzgY+Z2ZnHW2ghuuDURgCeWKvukiJSeKIEvLn7XuDdwPfc/X3A7KN9yN03Dz/Wz927\ngRVA64kUW2hmtdRRN6ZMzTQiUpAiBbyZnQdcA/z2GD43cgXTgAXAk4eYd4OZdZhZR1dX17GsNnHZ\njPGmGeN57MVtBB2NREQKR5Sg/jRwI/Brd19mZqcAD0XdgJlVA3cBn3b33QfPd/fb3L3d3dubmpqi\nrrZgnH9qI5t29fLy9r1JlyIi8jpR7ib5CPCImVWbWbW7rwM+GWXlZlZGEO4/d/e7T6zUwjTcDv/Y\ni11MbxybcDUiIgdEuVXBHDN7FlgGLDezxWY2K8LnDLgdWOHu3zzxUgvT1PFVTB43hkdWF1fzkoik\nX5QmmluBf3D3qe4+heB2Bd+P8LnzCXrcXGxmz4Wvy0+g1oJkZlw8cwKPrdlG70DJPslQRApQlIAf\n6+7729zd/WHgqG0R7v6Yu5u7z3X3+eHrdydQa8G6+A0T6R3I88Q6dZcUkcIRJeDXmdk/mtm08PVF\nYF3chRWTc6aPY0xZlgdXdCZdiojIflEC/iNAE3A3wQXTxnCahCrLslxwWiMPruxUd0kRKRhH7EVj\nZlngC+4eqddMKbv4jAk8sHwrq7f2MHNSyd9sU0QKwBHP4N19CLhglGopahfNnADAH1duTbgSEZFA\nlCaaZ83sHjO7zszePfyKvbIiM6muklkttTy0Uu3wIlIYogR8JbAduBi4MnxdEWdRxepvzpjA4vU7\n2NbTl3QpIiKRfsn64dEoJA0und3Mtx9cw33LtnDNOSX92FoRKQBRfsn6YzOrHzHeYGY/jLes4vSG\n5hqmN47l3he2JF2KiEikJpq57r5zeMTddxDcGVIOYmZcNnsST6zbzmt7+pMuR0RKXJSAz5hZw/CI\nmY0jQtNOqbp8TjNDeeeB5TqLF5FkRQn4fwGeMLMvmdmXgD8DX4u3rOI1q6WWKeOq+K2aaUQkYUcN\neHf/CcHTnLaGr3e7+0/jLqxYmRmXzZnEn9dsY4eaaUQkQZGezOTuy939O+FredxFFbt3zWthMO8s\nWrIp6VJEpIQd06P3JJozm2s5Y1INdz2zMelSRKSEKeBjYGa8e2Erz726k7VdPUmXIyIlSgEfk7+d\n30rG4Nc6ixeRhCjgYzKhtpI3n9bEr5/dSD6vWwiLyOhTwMfo3Qtb2bhzH39eqyc9icjoU8DH6B2z\nJtFQVcbPn1yfdCkiUoIU8DGqLMvy/rMmc//yrWzZ1Zt0OSJSYhTwMbvm7Knk3bnjqVeSLkVESowC\nPmZTxldx4elN3PHUKwwM5ZMuR0RKiAJ+FFx33lQ6u/u4d6nuTyMio0cBPwouPH0CpzSN5dZH1uKu\nLpMiMjoU8KMgkzE++pYZLNu0mz+9uC3pckSkRCjgR8lVC1qYVFvJLQ+vTboUESkRsQW8mf3QzDrN\nbGlc2ygmFbks//XN03li3XaefWVH0uWISAmI8wz+R8ClMa6/6Fx99hTqq8r41h9eTLoUESkBsQW8\nuz8KvBbX+ovR2IocH7vwVB5d3cWf16otXkTilXgbvJndYGYdZtbR1dWVdDmxu+68qbTUVfLV369S\njxoRiVXiAe/ut7l7u7u3NzU1JV1O7CrLsnz67afz/Ks71S9eRGKVeMCXovcsbGPmxBq+/LsV7Osf\nSrocEUkpBXwCshnjf141iw079vG/H9QFVxGJR5zdJO8AngBmmtkGM7s+rm0Vo3NOGc97Frbx/T+t\n48Wt3UmXIyIpFGcvmqvdvdndy9y9zd1vj2tbxep/XH4GVeU5brz7BYb01CcROcnURJOg8dUV3HTl\nmXSs38H/eUS/cBWRk0sBn7D/tKCVd85t5lsPrGbJhp1JlyMiKaKAT5iZ8eW/nUNTTQWfvONZdu0b\nSLokEUkJBXwBqKsq4+YPLmDDjn188o5n1R4vIieFAr5AnD19HP901SweWd3FV3+/MulyRCQFckkX\nIAdcc85UVm7u5rZH1zGxtpLrL5iedEkiUsQU8AXmpivPpKu7jy8tWk5NRY73nzU56ZJEpEipiabA\n5LIZbr56Pm8+rZHP372EXz71StIliUiRUsAXoIpclluveyNvPq2Jz9/9gvrIi8hxUcAXqKryHN//\nu3aunNfCV+5dyY13L6FvUDcmE5Ho1AZfwMpzGW7+wHwmN4zhew+vZcXmbm65diHNdWOSLk1EioDO\n4AtcJmN87tIzuOWahaze2s1lN/+JRUs2JV2WiBQBBXyRuGxOM7/5xAVMHT+Wj//iWT72i2fo3N2b\ndFkiUsAU8EVkRlM1d330PP7bJadz/7ItXPSNh7n1kbX0D+aTLk1ECpACvsjkshk+fvFpPPCZt3Le\njPH8870reds3H+FXT7/KwJCCXkQOUMAXqWmNY/nBh87ixx85m7oxZXzuriVc9I2H+fmT69nbP5h0\neSJSAMy9cG5s1d7e7h0dHUmXUXTcnYdWdXLzH9fw/Ks7qanM8d43tnHtuVOZ0VSddHkiEiMzW+zu\n7Yecp4BPD3enY/0OfvrEeu5dupmBIWf+5HreNa+FK+Y2M6G2MukSReQkU8CXoK7uPu5cvIF7nt/E\nis27yVhwx8qLz5jARTMncOqEasws6TJF5AQp4Evcms5u7nl+M/cv28LKLcEDvlvrx/CW05s4Z/o4\nzpo+jtZ6/XhKpBgp4GW/TTv38fCqLh5c2clf1m2npy+4INtaP4azpjUwf3I9s1vreENzLWMr9ENn\nkUKngJdDGso7Kzbv5umXX+Ppl1/jqZd2sK2nDwAzmN44ltktdcycVMOMpmpOnTCWKePGUp5T5yuR\nQqGAl0jcna27+1i6cRfLNu1m6aZdLNu4i027DvxiNpcxpoyvYkZTNdPGV9HWUEVr/Rjaxo2htX4M\nNZVlCe6BSOk5UsDrb3DZz8yYVFfJpLpK3nbmxP3Tu3sHeGnbHtZ29bC2M3hf09nDo6u76DvoV7R1\nY8porR9DS30lTTUVNNUE7xNqKoLx6uC9siw72rsnUnIU8HJUNZVlzG2rZ25b/eumuzvbevrZuHMf\nG3bsZeOOfWzYEQ7v7OW5V3eyfU8/h/ojsbYyx7ix5dRVldNQVUb9mDLqq8qpryqjIXyvD+fVVpZR\nXZmjuiJHRS6j3j8iESng5biZWXiWXsH8yfWHXGZwKM/2Pf10dffR1d1HZ3dv+N7Hzr0D7Njbz/ae\nftZ29bBzzwDdfUf+FW4uY/vDvroiR83wcGUZ1RVZqityjK3IUVWeZUxZlsqyLGPKs1SVh8Ph+Jhw\nuDIcLsvquoKkjwJeYpXLZphYW8nEiD+yGhjKs2vfADv39odfAAN09w7Q0zdId+8gPX2D9PQOsqdv\nkO5weFtPPy9v3xvOH6B34NjvyZPL2P7gryzLUp7LUJ7NBO+5DBXh6+Dp5dksFWUHpo1cJpgerCuX\nNXIZI5fJUJY1ctkMuYxRlh0xL5uhLHzPZY2yzIF5+qtFjkesAW9mlwI3A1ngB+7+lTi3J8WvLJuh\nsbqCxuqK417H4FCe3sE8+/qH6B0YYt/AEPv6w/fh4XC896B5w+P9Q3n6B/P0ha/u3kG2D+b3T+8P\nh/sGgmUHhuLtrJDN2EFfCMNfFMHw8BdELmNkMkbWIJfJkMkEn82Y7V/H8HCwnI34TPieCT9rwfDh\nlzvK+jKQseDLKWPBcMYIxw9Ms/3zRs4Pthtl+Uwm4vrCaZbhqMunRWwBb2ZZ4LvA24ENwNNmdo+7\nL49rmyIQ/NVQnc1QPYr9+PN5D8J/xBdD/4jXQD7PUN4ZGMozOOQM5oMvheHh10/LM5j3/cMDeWco\nXGbgEMsN5sPlhoLlhjyoZzCfJ58P/ioayjt5D7aXd2co7wy5h8sF70PuDOUJl8uTd/5quVIx8gsB\nC8aNA18GBuH010+z8AtiePjAdAvnHWIaMH5sBb/66HknfT/i/D/gbGCNu68DMLNfAlcBCnhJnUzG\nqMxkU9876MAXQfiFMfzlsP+LgGD+UDDu7uSd/e95Dz7n+4cJx8Ph/DEu/7r1B9s/puV9xPL5IyxP\nsIzvX/frp40cD74Hg1pGTvvr5Q+sr6YyniiOM+BbgVdHjG8Azjl4ITO7AbgBYMqUKTGWIyInKpMx\nMhgp/x5LjcS7Drj7be7e7u7tTU1NSZcjIpIacQb8RmDyiPG2cJqIiIyCOAP+aeA0M5tuZuXAB4F7\nYtyeiIiMEFsbvLsPmtnHgfsIukn+0N2XxbU9ERF5vVj7kbn774DfxbkNERE5tMQvsoqISDwU8CIi\nKaWAFxFJqYJ64IeZdQHrj/PjjcC2k1hOMdA+lwbtc/qdyP5OdfdD/oiooAL+RJhZx+GeapJW2ufS\noH1Ov7j2V000IiIppYAXEUmpNAX8bUkXkADtc2nQPqdfLPubmjZ4ERF5vTSdwYuIyAgKeBGRlCr6\ngDezS81slZmtMbPPJ13PyWJmk83sITNbbmbLzOxT4fRxZvaAmb0YvjeE083Mvh3+Oywxs4XJ7sHx\nM7OsmT1rZovC8elm9mS4b/83vDspZlYRjq8J509Lsu7jZWb1Znanma00sxVmdl7aj7OZfSb873qp\nmd1hZpVpO85m9kMz6zSzpSOmHfNxNbMPhcu/aGYfOpYaijrgRzz39TLgTOBqMzsz2apOmkHgs+5+\nJnAu8LFw3z4P/NHdTwP+GI5D8G9wWvi6Abhl9Es+aT4FrBgx/lXgW+5+KrADuD6cfj2wI5z+rXC5\nYnQz8Ht3PwOYR7DvqT3OZtYKfBJod/fZBHeb/SDpO84/Ai49aNoxHVczGwfcRPA0vLOBm4a/FCLx\n8PmDxfgCzgPuGzF+I3Bj0nXFtK//j+AB5quA5nBaM7AqHL4VuHrE8vuXK6YXwYNh/ghcDCwieCbx\nNiB38DEnuBX1eeFwLlzOkt6HY9zfOuClg+tO83HmwOM8x4XHbRHwjjQeZ2AasPR4jytwNXDriOmv\nW+5or6I+g+fQz31tTaiW2IR/ki4AngQmuvvmcNYWYGI4nJZ/i38FPgfkw/HxwE53HwzHR+7X/n0O\n5+8Kly8m04Eu4N/CZqkfmNlYUnyc3X0j8A3gFWAzwXFbTLqP87BjPa4ndLyLPeBTz8yqgbuAT7v7\n7pHzPPhKT00/VzO7Auh098VJ1zKKcsBC4BZ3XwDs4cCf7UAqj3MDcBXBl1sLMJa/bspIvdE4rsUe\n8Kl+7quZlRGE+8/d/e5w8lYzaw7nNwOd4fQ0/FucD7zLzF4GfknQTHMzUG9mww+nGblf+/c5nF8H\nbB/Ngk+CDcAGd38yHL+TIPDTfJzfBrzk7l3uPgDcTXDs03ychx3rcT2h413sAZ/a576amQG3Ayvc\n/ZsjZt0DDF9J/xBB2/zw9L8Lr8afC+wa8adgUXD3G929zd2nERzLB939GuAh4L3hYgfv8/C/xXvD\n5YvqTNfdtwCvmtnMcNLfAMtJ8XEmaJo518yqwv/Oh/c5tcd5hGM9rvcBl5hZQ/iXzyXhtGiSvghx\nEi5iXA6sBtYCX0i6npO4XxcQ/Pm2BHgufF1O0Pb4R+BF4A/AuHB5I+hRtBZ4gaCHQuL7cQL7fyGw\nKBw+BXgKWAP8O1ARTq8Mx9eE809Juu7j3Nf5QEd4rP8DaEj7cQb+CVgJLAV+ClSk7TgDdxBcYxgg\n+Evt+uM5rsBHwn1fA3z4WGrQrQpERFKq2JtoRETkMBTwIiIppYAXEUkpBbyISEop4EVEUkoBLyKS\nUgp4EZGU+v+vsJWbtkFCpQAAAABJRU5ErkJggg==\n", |
| 209 | + "text/plain": [ |
| 210 | + "<Figure size 432x288 with 1 Axes>" |
| 211 | + ] |
| 212 | + }, |
| 213 | + "metadata": {}, |
| 214 | + "output_type": "display_data" |
| 215 | + } |
| 216 | + ], |
| 217 | + "source": [ |
| 218 | + "plt.plot(log)\n", |
| 219 | + "plt.ylabel('cross entropy loss')\n", |
| 220 | + "plt.show()" |
| 221 | + ] |
| 222 | + } |
| 223 | + ], |
| 224 | + "metadata": { |
| 225 | + "kernelspec": { |
| 226 | + "display_name": "Python 3", |
| 227 | + "language": "python", |
| 228 | + "name": "python3" |
| 229 | + }, |
| 230 | + "language_info": { |
| 231 | + "codemirror_mode": { |
| 232 | + "name": "ipython", |
| 233 | + "version": 3 |
| 234 | + }, |
| 235 | + "file_extension": ".py", |
| 236 | + "mimetype": "text/x-python", |
| 237 | + "name": "python", |
| 238 | + "nbconvert_exporter": "python", |
| 239 | + "pygments_lexer": "ipython3", |
| 240 | + "version": "3.6.1" |
| 241 | + } |
| 242 | + }, |
| 243 | + "nbformat": 4, |
| 244 | + "nbformat_minor": 2 |
| 245 | +} |
0 commit comments