|
20 | 20 | },
|
21 | 21 | {
|
22 | 22 | "cell_type": "code",
|
23 |
| - "execution_count": 10, |
| 23 | + "execution_count": 1, |
24 | 24 | "id": "4815068f-bb52-47fb-9d35-3a340663c758",
|
25 | 25 | "metadata": {},
|
26 | 26 | "outputs": [],
|
27 | 27 | "source": [
|
28 | 28 | "import gymnasium as gym\n",
|
29 |
| - "from utils.test_env import TestEnv\n", |
30 |
| - "from utils.grid_search import GridSearch\n", |
31 |
| - "from utils.blackjack_wrapper import BlackjackWrapper\n", |
32 |
| - "from algorithms.rl import RL\n", |
33 |
| - "from algorithms.planner import Planner" |
| 29 | + "from bettermdptools.utils.test_env import TestEnv\n", |
| 30 | + "from bettermdptools.utils.grid_search import GridSearch\n", |
| 31 | + "from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n", |
| 32 | + "from bettermdptools.algorithms.rl import RL\n", |
| 33 | + "from bettermdptools.algorithms.planner import Planner" |
34 | 34 | ]
|
35 | 35 | },
|
36 | 36 | {
|
37 | 37 | "cell_type": "code",
|
38 |
| - "execution_count": 4, |
| 38 | + "execution_count": 2, |
39 | 39 | "id": "448dc393-ba4f-412d-ac89-6c71a22cfeaf",
|
40 | 40 | "metadata": {},
|
41 | 41 | "outputs": [
|
|
50 | 50 | "name": "stdout",
|
51 | 51 | "output_type": "stream",
|
52 | 52 | "text": [
|
53 |
| - "runtime = 0.67 seconds\n", |
| 53 | + "runtime = 0.64 seconds\n", |
54 | 54 | "Avg. episode reward: 0.0\n",
|
55 | 55 | "###################\n",
|
56 | 56 | "running q_learning with gamma: 0.99 epsilon decay: 0.9 iterations: 5000\n"
|
57 | 57 | ]
|
58 | 58 | },
|
59 |
| - { |
60 |
| - "name": "stderr", |
61 |
| - "output_type": "stream", |
62 |
| - "text": [ |
63 |
| - " \r" |
64 |
| - ] |
65 |
| - }, |
66 | 59 | {
|
67 | 60 | "name": "stdout",
|
68 | 61 | "output_type": "stream",
|
69 | 62 | "text": [
|
70 |
| - "runtime = 6.56 seconds\n", |
| 63 | + "runtime = 6.49 seconds\n", |
71 | 64 | "Avg. episode reward: 0.0\n",
|
72 | 65 | "###################\n",
|
73 | 66 | "running q_learning with gamma: 0.99 epsilon decay: 0.9 iterations: 50000\n"
|
74 | 67 | ]
|
75 | 68 | },
|
76 |
| - { |
77 |
| - "name": "stderr", |
78 |
| - "output_type": "stream", |
79 |
| - "text": [ |
80 |
| - " " |
81 |
| - ] |
82 |
| - }, |
83 | 69 | {
|
84 | 70 | "name": "stdout",
|
85 | 71 | "output_type": "stream",
|
86 | 72 | "text": [
|
87 |
| - "runtime = 53.30 seconds\n", |
88 |
| - "Avg. episode reward: 0.83\n", |
| 73 | + "runtime = 54.57 seconds\n", |
| 74 | + "Avg. episode reward: 0.88\n", |
89 | 75 | "###################\n"
|
90 | 76 | ]
|
91 |
| - }, |
92 |
| - { |
93 |
| - "name": "stderr", |
94 |
| - "output_type": "stream", |
95 |
| - "text": [ |
96 |
| - "\r" |
97 |
| - ] |
98 | 77 | }
|
99 | 78 | ],
|
100 | 79 | "source": [
|
|
107 | 86 | },
|
108 | 87 | {
|
109 | 88 | "cell_type": "code",
|
110 |
| - "execution_count": 6, |
| 89 | + "execution_count": 3, |
111 | 90 | "id": "60161c93-6183-4de6-a937-7c12b9742fb7",
|
112 | 91 | "metadata": {},
|
113 | 92 | "outputs": [
|
|
116 | 95 | "output_type": "stream",
|
117 | 96 | "text": [
|
118 | 97 | "running VI with gamma: 0.7 n_iters: 500 theta: 0.001\n",
|
119 |
| - "runtime = 0.02 seconds\n", |
120 |
| - "Avg. episode reward: 0.01\n", |
| 98 | + "runtime = 0.01 seconds\n", |
| 99 | + "Avg. episode reward: -0.16\n", |
121 | 100 | "###################\n",
|
122 | 101 | "running VI with gamma: 0.7 n_iters: 500 theta: 1e-05\n",
|
123 | 102 | "runtime = 0.02 seconds\n",
|
124 |
| - "Avg. episode reward: -0.07\n", |
| 103 | + "Avg. episode reward: 0.18\n", |
125 | 104 | "###################\n",
|
126 | 105 | "running VI with gamma: 0.9 n_iters: 500 theta: 0.001\n",
|
127 | 106 | "runtime = 0.01 seconds\n",
|
128 |
| - "Avg. episode reward: -0.16\n", |
| 107 | + "Avg. episode reward: -0.06\n", |
129 | 108 | "###################\n",
|
130 | 109 | "running VI with gamma: 0.9 n_iters: 500 theta: 1e-05\n",
|
131 |
| - "runtime = 0.02 seconds\n", |
132 |
| - "Avg. episode reward: 0.01\n", |
| 110 | + "runtime = 0.01 seconds\n", |
| 111 | + "Avg. episode reward: -0.1\n", |
133 | 112 | "###################\n",
|
134 | 113 | "running VI with gamma: 0.99 n_iters: 500 theta: 0.001\n",
|
135 | 114 | "runtime = 0.01 seconds\n",
|
136 |
| - "Avg. episode reward: 0.13\n", |
| 115 | + "Avg. episode reward: -0.1\n", |
137 | 116 | "###################\n",
|
138 | 117 | "running VI with gamma: 0.99 n_iters: 500 theta: 1e-05\n",
|
139 | 118 | "runtime = 0.01 seconds\n",
|
140 |
| - "Avg. episode reward: 0.07\n", |
| 119 | + "Avg. episode reward: -0.04\n", |
141 | 120 | "###################\n"
|
142 | 121 | ]
|
143 | 122 | }
|
|
164 | 143 | "id": "c329484d-db1b-48fe-b7a0-b035ad5356bb",
|
165 | 144 | "metadata": {},
|
166 | 145 | "source": [
|
167 |
| - "RL algorithms SARSA and Q-learning have callback hooks for episode number, begin, end, and env. step. To create a callback, override one of the callback functions in the child class MyCallbacks. Or, you can use the add_to decorator and define the override outside of the class definition. For example, print the episode number every 1000 episodes." |
| 146 | + "RL algorithms SARSA and Q-learning have callback hooks for episode number, begin, end, and env. step. To create a callback, override one of the callback functions in the child class MyCallbacks. Or, you can use the add_to decorator to define the override outside of the class definition. For example, print the episode number every 1000 episodes." |
168 | 147 | ]
|
169 | 148 | },
|
170 | 149 | {
|
171 | 150 | "cell_type": "code",
|
172 |
| - "execution_count": 11, |
| 151 | + "execution_count": 4, |
173 | 152 | "id": "787ce4d3-e2b6-459f-94a6-ae4201fa091f",
|
174 | 153 | "metadata": {},
|
175 | 154 | "outputs": [
|
176 | 155 | {
|
177 | 156 | "name": "stderr",
|
178 | 157 | "output_type": "stream",
|
179 | 158 | "text": [
|
180 |
| - " 16%|█▌ | 1580/10000 [00:00<00:01, 7893.83it/s]" |
| 159 | + " 17%|█▋ | 1661/10000 [00:00<00:01, 8314.69it/s]" |
181 | 160 | ]
|
182 | 161 | },
|
183 | 162 | {
|
|
192 | 171 | "name": "stderr",
|
193 | 172 | "output_type": "stream",
|
194 | 173 | "text": [
|
195 |
| - " 32%|███▏ | 3177/10000 [00:00<00:00, 7868.57it/s]" |
| 174 | + " 33%|███▎ | 3302/10000 [00:00<00:00, 8044.59it/s]" |
196 | 175 | ]
|
197 | 176 | },
|
198 | 177 | {
|
|
207 | 186 | "name": "stderr",
|
208 | 187 | "output_type": "stream",
|
209 | 188 | "text": [
|
210 |
| - " 55%|█████▌ | 5530/10000 [00:00<00:00, 7772.50it/s]" |
| 189 | + " 49%|████▉ | 4911/10000 [00:00<00:00, 7923.74it/s]" |
211 | 190 | ]
|
212 | 191 | },
|
213 | 192 | {
|
|
222 | 201 | "name": "stderr",
|
223 | 202 | "output_type": "stream",
|
224 | 203 | "text": [
|
225 |
| - " 71%|███████ | 7053/10000 [00:00<00:00, 7329.64it/s]" |
| 204 | + " 73%|███████▎ | 7295/10000 [00:00<00:00, 7878.84it/s]" |
226 | 205 | ]
|
227 | 206 | },
|
228 | 207 | {
|
|
237 | 216 | "name": "stderr",
|
238 | 217 | "output_type": "stream",
|
239 | 218 | "text": [
|
240 |
| - " 93%|█████████▎| 9256/10000 [00:01<00:00, 7189.14it/s]" |
| 219 | + " 88%|████████▊ | 8843/10000 [00:01<00:00, 7328.18it/s]" |
241 | 220 | ]
|
242 | 221 | },
|
243 | 222 | {
|
|
259 | 238 | "name": "stdout",
|
260 | 239 | "output_type": "stream",
|
261 | 240 | "text": [
|
262 |
| - "runtime = 1.35 seconds\n" |
| 241 | + "runtime = 1.31 seconds\n" |
263 | 242 | ]
|
264 | 243 | },
|
265 | 244 | {
|
|
271 | 250 | }
|
272 | 251 | ],
|
273 | 252 | "source": [
|
274 |
| - "from utils.decorators import add_to\n", |
275 |
| - "from utils.callbacks import MyCallbacks\n", |
276 |
| - "from algorithms.rl import RL\n", |
277 |
| - "from utils.blackjack_wrapper import BlackjackWrapper\n", |
| 253 | + "from bettermdptools.utils.decorators import add_to\n", |
| 254 | + "from bettermdptools.utils.callbacks import MyCallbacks\n", |
| 255 | + "from bettermdptools.algorithms.rl import RL\n", |
| 256 | + "from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n", |
278 | 257 | "\n",
|
279 | 258 | "base_env = gym.make('Blackjack-v1', render_mode=None)\n",
|
280 | 259 | "blackjack = BlackjackWrapper(base_env)\n",
|
|
0 commit comments