Skip to content

Commit 5fa5b17

Browse files
committed
update notebooks
1 parent a24cb90 commit 5fa5b17

File tree

5 files changed

+66
-87
lines changed

5 files changed

+66
-87
lines changed

notebooks/blackjack.ipynb

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@
2525
},
2626
{
2727
"cell_type": "code",
28-
"execution_count": 10,
28+
"execution_count": 1,
2929
"id": "900376a8-792f-41bc-988c-2bc23ff2d7d4",
3030
"metadata": {},
3131
"outputs": [],
3232
"source": [
3333
"import gymnasium as gym\n",
34-
"from utils.blackjack_wrapper import BlackjackWrapper\n",
35-
"from utils.test_env import TestEnv\n",
36-
"from algorithms.planner import Planner\n",
37-
"from algorithms.rl import RL\n",
34+
"from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n",
35+
"from bettermdptools.utils.test_env import TestEnv\n",
36+
"from bettermdptools.algorithms.planner import Planner\n",
37+
"from bettermdptools.algorithms.rl import RL\n",
3838
"import numpy as np"
3939
]
4040
},
4141
{
4242
"cell_type": "code",
43-
"execution_count": 14,
43+
"execution_count": 2,
4444
"id": "265aae24-a1b3-4400-8bcc-8da16e9a4612",
4545
"metadata": {},
4646
"outputs": [
@@ -49,22 +49,15 @@
4949
"output_type": "stream",
5050
"text": [
5151
"runtime = 0.03 seconds\n",
52-
"0.2\n"
53-
]
54-
},
55-
{
56-
"name": "stderr",
57-
"output_type": "stream",
58-
"text": [
59-
" "
52+
"-0.03\n"
6053
]
6154
},
6255
{
6356
"name": "stdout",
6457
"output_type": "stream",
6558
"text": [
66-
"runtime = 1.34 seconds\n",
67-
"0.01\n"
59+
"runtime = 1.30 seconds\n",
60+
"0.1\n"
6861
]
6962
},
7063
{

notebooks/frozen_lake.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"outputs": [],
2929
"source": [
3030
"import gymnasium as gym\n",
31-
"from algorithms.planner import Planner\n",
32-
"from utils.plots import Plots"
31+
"from bettermdptools.algorithms.planner import Planner\n",
32+
"from bettermdptools.utils.plots import Plots"
3333
]
3434
},
3535
{
@@ -42,7 +42,7 @@
4242
"name": "stdout",
4343
"output_type": "stream",
4444
"text": [
45-
"runtime = 0.51 seconds\n"
45+
"runtime = 0.42 seconds\n"
4646
]
4747
},
4848
{

notebooks/other_utilities.ipynb

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,22 @@
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": 10,
23+
"execution_count": 1,
2424
"id": "4815068f-bb52-47fb-9d35-3a340663c758",
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
2828
"import gymnasium as gym\n",
29-
"from utils.test_env import TestEnv\n",
30-
"from utils.grid_search import GridSearch\n",
31-
"from utils.blackjack_wrapper import BlackjackWrapper\n",
32-
"from algorithms.rl import RL\n",
33-
"from algorithms.planner import Planner"
29+
"from bettermdptools.utils.test_env import TestEnv\n",
30+
"from bettermdptools.utils.grid_search import GridSearch\n",
31+
"from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n",
32+
"from bettermdptools.algorithms.rl import RL\n",
33+
"from bettermdptools.algorithms.planner import Planner"
3434
]
3535
},
3636
{
3737
"cell_type": "code",
38-
"execution_count": 4,
38+
"execution_count": 2,
3939
"id": "448dc393-ba4f-412d-ac89-6c71a22cfeaf",
4040
"metadata": {},
4141
"outputs": [
@@ -50,51 +50,30 @@
5050
"name": "stdout",
5151
"output_type": "stream",
5252
"text": [
53-
"runtime = 0.67 seconds\n",
53+
"runtime = 0.64 seconds\n",
5454
"Avg. episode reward: 0.0\n",
5555
"###################\n",
5656
"running q_learning with gamma: 0.99 epsilon decay: 0.9 iterations: 5000\n"
5757
]
5858
},
59-
{
60-
"name": "stderr",
61-
"output_type": "stream",
62-
"text": [
63-
" \r"
64-
]
65-
},
6659
{
6760
"name": "stdout",
6861
"output_type": "stream",
6962
"text": [
70-
"runtime = 6.56 seconds\n",
63+
"runtime = 6.49 seconds\n",
7164
"Avg. episode reward: 0.0\n",
7265
"###################\n",
7366
"running q_learning with gamma: 0.99 epsilon decay: 0.9 iterations: 50000\n"
7467
]
7568
},
76-
{
77-
"name": "stderr",
78-
"output_type": "stream",
79-
"text": [
80-
" "
81-
]
82-
},
8369
{
8470
"name": "stdout",
8571
"output_type": "stream",
8672
"text": [
87-
"runtime = 53.30 seconds\n",
88-
"Avg. episode reward: 0.83\n",
73+
"runtime = 54.57 seconds\n",
74+
"Avg. episode reward: 0.88\n",
8975
"###################\n"
9076
]
91-
},
92-
{
93-
"name": "stderr",
94-
"output_type": "stream",
95-
"text": [
96-
"\r"
97-
]
9877
}
9978
],
10079
"source": [
@@ -107,7 +86,7 @@
10786
},
10887
{
10988
"cell_type": "code",
110-
"execution_count": 6,
89+
"execution_count": 3,
11190
"id": "60161c93-6183-4de6-a937-7c12b9742fb7",
11291
"metadata": {},
11392
"outputs": [
@@ -116,28 +95,28 @@
11695
"output_type": "stream",
11796
"text": [
11897
"running VI with gamma: 0.7 n_iters: 500 theta: 0.001\n",
119-
"runtime = 0.02 seconds\n",
120-
"Avg. episode reward: 0.01\n",
98+
"runtime = 0.01 seconds\n",
99+
"Avg. episode reward: -0.16\n",
121100
"###################\n",
122101
"running VI with gamma: 0.7 n_iters: 500 theta: 1e-05\n",
123102
"runtime = 0.02 seconds\n",
124-
"Avg. episode reward: -0.07\n",
103+
"Avg. episode reward: 0.18\n",
125104
"###################\n",
126105
"running VI with gamma: 0.9 n_iters: 500 theta: 0.001\n",
127106
"runtime = 0.01 seconds\n",
128-
"Avg. episode reward: -0.16\n",
107+
"Avg. episode reward: -0.06\n",
129108
"###################\n",
130109
"running VI with gamma: 0.9 n_iters: 500 theta: 1e-05\n",
131-
"runtime = 0.02 seconds\n",
132-
"Avg. episode reward: 0.01\n",
110+
"runtime = 0.01 seconds\n",
111+
"Avg. episode reward: -0.1\n",
133112
"###################\n",
134113
"running VI with gamma: 0.99 n_iters: 500 theta: 0.001\n",
135114
"runtime = 0.01 seconds\n",
136-
"Avg. episode reward: 0.13\n",
115+
"Avg. episode reward: -0.1\n",
137116
"###################\n",
138117
"running VI with gamma: 0.99 n_iters: 500 theta: 1e-05\n",
139118
"runtime = 0.01 seconds\n",
140-
"Avg. episode reward: 0.07\n",
119+
"Avg. episode reward: -0.04\n",
141120
"###################\n"
142121
]
143122
}
@@ -164,20 +143,20 @@
164143
"id": "c329484d-db1b-48fe-b7a0-b035ad5356bb",
165144
"metadata": {},
166145
"source": [
167-
"RL algorithms SARSA and Q-learning have callback hooks for episode number, begin, end, and env. step. To create a callback, override one of the callback functions in the child class MyCallbacks. Or, you can use the add_to decorator and define the override outside of the class definition. For example, print the episode number every 1000 episodes."
146+
"RL algorithms SARSA and Q-learning have callback hooks for episode number, begin, end, and env. step. To create a callback, override one of the callback functions in the child class MyCallbacks. Or, you can use the add_to decorator to define the override outside of the class definition. For example, print the episode number every 1000 episodes."
168147
]
169148
},
170149
{
171150
"cell_type": "code",
172-
"execution_count": 11,
151+
"execution_count": 4,
173152
"id": "787ce4d3-e2b6-459f-94a6-ae4201fa091f",
174153
"metadata": {},
175154
"outputs": [
176155
{
177156
"name": "stderr",
178157
"output_type": "stream",
179158
"text": [
180-
" 16%|█ | 1580/10000 [00:00<00:01, 7893.83it/s]"
159+
" 17%|█ | 1661/10000 [00:00<00:01, 8314.69it/s]"
181160
]
182161
},
183162
{
@@ -192,7 +171,7 @@
192171
"name": "stderr",
193172
"output_type": "stream",
194173
"text": [
195-
" 32%|███ | 3177/10000 [00:00<00:00, 7868.57it/s]"
174+
" 33%|███ | 3302/10000 [00:00<00:00, 8044.59it/s]"
196175
]
197176
},
198177
{
@@ -207,7 +186,7 @@
207186
"name": "stderr",
208187
"output_type": "stream",
209188
"text": [
210-
" 55%|█████▌ | 5530/10000 [00:00<00:00, 7772.50it/s]"
189+
" 49%|████ | 4911/10000 [00:00<00:00, 7923.74it/s]"
211190
]
212191
},
213192
{
@@ -222,7 +201,7 @@
222201
"name": "stderr",
223202
"output_type": "stream",
224203
"text": [
225-
" 71%|███████ | 7053/10000 [00:00<00:00, 7329.64it/s]"
204+
" 73%|███████▎ | 7295/10000 [00:00<00:00, 7878.84it/s]"
226205
]
227206
},
228207
{
@@ -237,7 +216,7 @@
237216
"name": "stderr",
238217
"output_type": "stream",
239218
"text": [
240-
" 93%|█████████▎| 9256/10000 [00:01<00:00, 7189.14it/s]"
219+
" 88%|████████▊ | 8843/10000 [00:01<00:00, 7328.18it/s]"
241220
]
242221
},
243222
{
@@ -259,7 +238,7 @@
259238
"name": "stdout",
260239
"output_type": "stream",
261240
"text": [
262-
"runtime = 1.35 seconds\n"
241+
"runtime = 1.31 seconds\n"
263242
]
264243
},
265244
{
@@ -271,10 +250,10 @@
271250
}
272251
],
273252
"source": [
274-
"from utils.decorators import add_to\n",
275-
"from utils.callbacks import MyCallbacks\n",
276-
"from algorithms.rl import RL\n",
277-
"from utils.blackjack_wrapper import BlackjackWrapper\n",
253+
"from bettermdptools.utils.decorators import add_to\n",
254+
"from bettermdptools.utils.callbacks import MyCallbacks\n",
255+
"from bettermdptools.algorithms.rl import RL\n",
256+
"from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n",
278257
"\n",
279258
"base_env = gym.make('Blackjack-v1', render_mode=None)\n",
280259
"blackjack = BlackjackWrapper(base_env)\n",

notebooks/plots.ipynb

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
"outputs": [],
2727
"source": [
2828
"import gymnasium as gym\n",
29-
"from utils.test_env import TestEnv\n",
30-
"from algorithms.rl import RL\n",
31-
"from algorithms.planner import Planner\n",
32-
"from utils.plots import Plots\n",
33-
"from utils.blackjack_wrapper import BlackjackWrapper\n",
29+
"from bettermdptools.utils.test_env import TestEnv\n",
30+
"from bettermdptools.algorithms.rl import RL\n",
31+
"from bettermdptools.algorithms.planner import Planner\n",
32+
"from bettermdptools.utils.plots import Plots\n",
33+
"from bettermdptools.utils.blackjack_wrapper import BlackjackWrapper\n",
3434
"import numpy as np\n",
3535
"import seaborn as sns\n",
3636
"import matplotlib.pyplot as plt"
@@ -54,7 +54,7 @@
5454
"name": "stdout",
5555
"output_type": "stream",
5656
"text": [
57-
"runtime = 0.69 seconds\n"
57+
"runtime = 0.57 seconds\n"
5858
]
5959
},
6060
{
@@ -117,7 +117,7 @@
117117
},
118118
{
119119
"cell_type": "code",
120-
"execution_count": 5,
120+
"execution_count": 4,
121121
"id": "5a1642ef-4a93-4d08-96f7-d62b58a63e58",
122122
"metadata": {},
123123
"outputs": [
@@ -150,7 +150,7 @@
150150
},
151151
{
152152
"cell_type": "code",
153-
"execution_count": 11,
153+
"execution_count": 5,
154154
"id": "8b7c01ed-cb3e-45d5-9f88-c8733ae3c6f0",
155155
"metadata": {},
156156
"outputs": [
@@ -197,12 +197,12 @@
197197
"id": "31f641c6-d465-4c49-a7bc-770a0aaebe12",
198198
"metadata": {},
199199
"source": [
200-
"### Using the add_to decorator to make changes on the fly. "
200+
"### Customized policy map using the add_to decorator. "
201201
]
202202
},
203203
{
204204
"cell_type": "code",
205-
"execution_count": 12,
205+
"execution_count": 7,
206206
"id": "e4c8ae11",
207207
"metadata": {},
208208
"outputs": [
@@ -218,7 +218,7 @@
218218
}
219219
],
220220
"source": [
221-
"from utils.decorators import add_to\n",
221+
"from bettermdptools.utils.decorators import add_to\n",
222222
"\n",
223223
"@add_to(Plots)\n",
224224
"@staticmethod\n",

notebooks/taxi.ipynb

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
"outputs": [],
2929
"source": [
3030
"import gymnasium as gym\n",
31-
"from algorithms.planner import Planner\n",
32-
"from utils.plots import Plots\n",
33-
"from utils.test_env import TestEnv"
31+
"from bettermdptools.algorithms.planner import Planner\n",
32+
"from bettermdptools.utils.plots import Plots\n",
33+
"from bettermdptools.utils.test_env import TestEnv"
3434
]
3535
},
3636
{
@@ -44,7 +44,14 @@
4444
"output_type": "stream",
4545
"text": [
4646
"runtime = 0.04 seconds\n",
47-
"[ 5. 10. 6. 4. 11. 9. 8. 6. 8. 6.]\n"
47+
"[12. 8. 11. 8. 8. 6. 8. 7. 10. 8.]\n"
48+
]
49+
},
50+
{
51+
"name": "stdout",
52+
"output_type": "stream",
53+
"text": [
54+
""
4855
]
4956
}
5057
],

0 commit comments

Comments
 (0)