|
107 | 107 | "import numpy as np\n",
|
108 | 108 | "import torch as th\n",
|
109 | 109 | "from collections import OrderedDict\n",
|
| 110 | + "\n", |
110 | 111 | "# from lucid.modelzoo.vision_base import Model\n",
|
111 | 112 | "# from lucid.scratch.rl_util import *\n",
|
112 | 113 | "# get_abbreviator defined during setup\n",
|
113 | | - "from reward_preprocessing.ext.notebook_util import CategoricalPolicyGM, ImpalaModel, get_abbreviator\n",
| 114 | + "from reward_preprocessing.ext.notebook_util import (\n", |
| 115 | + " CategoricalPolicyGM,\n", |
| 116 | + " ImpalaModel,\n", |
| 117 | + " get_abbreviator,\n", |
| 118 | + ")\n", |
114 | 119 | "from reward_preprocessing.vis.reward_vis import LayerNMF\n",
|
115 | 120 | "from reward_preprocessing.vis.util import zoom_to"
|
116 | 121 | ],
|
|
157 | 162 | "model_state_dict = th.load(Path(model_path).expanduser(), device)[\"model_state_dict\"]\n",
|
158 | 163 | "\n",
|
159 | 164 | "embedder = ImpalaModel(in_channels=3)\n",
|
160 | | - "model = CategoricalPolicyGM(embedder=embedder, action_size=model_state_dict[\"fc_policy.weight\"].shape[0])\n",
| 165 | + "model = CategoricalPolicyGM(\n", |
| 166 | + " embedder=embedder, action_size=model_state_dict[\"fc_policy.weight\"].shape[0]\n", |
| 167 | + ")\n", |
161 | 168 | "# Load data\n",
|
162 | 169 | "model.load_state_dict(model_state_dict)\n",
|
163 | 170 | "\n",
|
164 | | - "value_function_name = 'fc_value'\n",
| 171 | + "value_function_name = \"fc_value\"\n", |
165 | 172 | "\n",
|
166 | 173 | "# Load trajectories for dataset visualization.\n",
|
167 | | - "trajectories = demonstrations.load_expert_trajs(str(Path(trajectories_path).expanduser()), n_expert_demos=None)\n",
| 174 | + "trajectories = demonstrations.load_expert_trajs(\n", |
| 175 | + " str(Path(trajectories_path).expanduser()), n_expert_demos=None\n", |
| 176 | + ")\n", |
168 | 177 | "trajectories = flatten_trajectories(trajectories)\n",
|
169 | 178 | "\n",
|
170 | 179 | "# Get observations from trajectories.\n",
|
|
174 | 183 | "\n",
|
175 | 184 | "layer_names = get_model_layers(model)\n",
|
176 | 185 | "abbreviator = get_abbreviator(layer_names)\n",
|
177 | | - "layer_names = OrderedDict(\n",

178 | | - "    [(name[abbreviator], name) for name in layer_names]\n",

179 | | - ")\n"
| 186 | + "layer_names = OrderedDict([(name[abbreviator], name) for name in layer_names])" |
180 | 187 | ],
|
181 | 188 | "execution_count": 2,
|
182 | 189 | "outputs": []
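The `OrderedDict` line above relies on `get_abbreviator`, which is defined during notebook setup and not shown in this diff. The slice-based behavior below is an assumption inferred from the `name[abbreviator]` indexing; a minimal sketch of what such a helper might look like:

```python
import os
from collections import OrderedDict

def get_abbreviator(names):
    # Return a slice that strips the longest common prefix and suffix,
    # so long layer names collapse to short unique keys like "2b".
    if len(names) <= 1:
        return slice(None)
    prefix = os.path.commonprefix(list(names))
    suffix = os.path.commonprefix([name[::-1] for name in names])[::-1]
    return slice(len(prefix), -len(suffix) if suffix else None)

# Hypothetical layer names, just to show the mapping:
names = ["embedder_block1_relu", "embedder_block2_relu"]
abbrev = get_abbreviator(names)
print(OrderedDict([(name[abbrev], name) for name in names]))
# OrderedDict([('1', 'embedder_block1_relu'), ('2', 'embedder_block2_relu')])
```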
|
|
399 | 406 | },
|
400 | 407 | "source": [
|
401 | 408 | "model.eval()\n",
|
402 | | - "layer = 'embedder_relu_after_convs'\n",
| 409 | + "layer = \"embedder_relu_after_convs\"\n", |
403 | 410 | "# value_function_name = None\n",
|
404 | 411 | "# can take a couple of minutes\n",
|
405 | 412 | "# for the paper, we use observations[:], but this requires more memory\n",
|
406 | | - "nmf = LayerNMF(model, layer, observations[:1024], features=None, attr_layer_name=value_function_name)"
| 413 | + "nmf = LayerNMF(\n", |
| 414 | + " model,\n", |
| 415 | + " layer,\n", |
| 416 | + " observations[:1024],\n", |
| 417 | + " features=None,\n", |
| 418 | + " attr_layer_name=value_function_name,\n", |
| 419 | + ")" |
407 | 420 | ],
|
408 | 421 | "execution_count": 12,
|
409 | 422 | "outputs": []
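`LayerNMF` factors the layer's activations over the observation set into a small number of nonnegative channel directions; passing `attr_layer_name=value_function_name` additionally weights activations by their attribution to the value output. A rough sketch of the core factorization using scikit-learn (the shapes and `n_components` here are illustrative assumptions, not what `LayerNMF` actually picks):

```python
import numpy as np
from sklearn.decomposition import NMF

acts = np.random.rand(1024, 8, 8, 32)      # (batch, h, w, channels), hypothetical shape
flat = acts.reshape(-1, acts.shape[-1])    # one row per spatial position
model_nmf = NMF(n_components=6, init="nndsvd", max_iter=500)
weights = model_nmf.fit_transform(flat)    # how strongly each position expresses each direction
channel_dirs = model_nmf.components_       # (n_components, channels), analogous to nmf.channel_dirs
```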
|
|
579 | 592 | },
|
580 | 593 | "source": [
|
581 | 594 | "# Show expects channels last, unlike the rest of lucent. Therefore we need to transpose here.\n",
|
582 | | - "show([zoom_to(nmf.vis_dataset_thumbnail(i, num_mult=4, expand_mult=4, max_rep=np.inf)[0], 200).transpose(1,2,0) for i in range(nmf.features)])"
| 595 | + "show(\n", |
| 596 | + " [\n", |
| 597 | + " zoom_to(\n", |
| 598 | + " nmf.vis_dataset_thumbnail(i, num_mult=4, expand_mult=4, max_rep=np.inf)[0],\n", |
| 599 | + " 200,\n", |
| 600 | + " ).transpose(1, 2, 0)\n", |
| 601 | + " for i in range(nmf.features)\n", |
| 602 | + " ]\n", |
| 603 | + ")" |
583 | 604 | ],
|
584 | 605 | "execution_count": 91,
|
585 | 606 | "outputs": [
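As the comment in this cell notes, `show` expects channels-last images while `zoom_to` (like the rest of lucent) returns channels-first, hence the `.transpose(1, 2, 0)`. In isolation:

```python
import numpy as np

img_chw = np.zeros((3, 200, 200))        # (C, H, W), as zoom_to returns it
img_hwc = img_chw.transpose(1, 2, 0)     # (H, W, C), as show expects
print(img_hwc.shape)                     # (200, 200, 3)
```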
|
|
784 | 805 | }
|
785 | 806 | },
|
786 | 807 | "source": [
|
787 | | - "traj = trajectories['observations'][0][76:84]\n",

788 | | - "attr = get_attr(model, value_function_name, layer_names['2b'], traj, integrate_steps=10)\n",
| 808 | + "traj = trajectories[\"observations\"][0][76:84]\n", |
| 809 | + "attr = get_attr(model, value_function_name, layer_names[\"2b\"], traj, integrate_steps=10)\n", |
789 | 810 | "attr.shape"
|
790 | 811 | ],
|
791 | 812 | "execution_count": 16,
|
|
840 | 861 | }
|
841 | 862 | },
|
842 | 863 | "source": [
|
843 | | - "attr_reduced = nmf.transform(np.maximum(attr, 0)) - nmf.transform(np.maximum(-attr, 0)) # transform the positive and negative parts separately\n",
| 864 | + "attr_reduced = nmf.transform(np.maximum(attr, 0)) - nmf.transform(\n", |
| 865 | + " np.maximum(-attr, 0)\n", |
| 866 | + ") # transform the positive and negative parts separately\n", |
844 | 867 | "nmf_norms = nmf.channel_dirs.sum(-1)\n",
|
845 | | - "attr_reduced *= nmf_norms[None, None, None] # multiply by the norms of the NMF directions, since the magnitudes of the NMF directions are not relevant\n",

846 | | - "attr_reduced /= np.median(attr_reduced.max(axis=(-3, -2, -1))) # globally normalize by the median max value to make the visualization balanced (a bit of a hack)\n",
| 868 | + "attr_reduced *= nmf_norms[\n", |
| 869 | + " None, None, None\n", |
| 870 | + "] # multiply by the norms of the NMF directions, since the magnitudes of the NMF directions are not relevant\n", |
| 871 | + "attr_reduced /= np.median(\n", |
| 872 | + " attr_reduced.max(axis=(-3, -2, -1))\n", |
| 873 | + ") # globally normalize by the median max value to make the visualization balanced (a bit of a hack)\n", |
847 | 874 | "attr_reduced.shape"
|
848 | 875 | ],
|
849 | 876 | "execution_count": 17,
|
|
1315 | 1342 | }
|
1316 | 1343 | },
|
1317 | 1344 | "source": [
|
1318 | | - "kernel_name = layer_names[\"3a\"].replace(\"Relu\", \"conv2d/kernel\") # name of tensor of convolutional kernel of next layer\n",
| 1345 | + "kernel_name = layer_names[\"3a\"].replace(\n", |
| 1346 | + " \"Relu\", \"conv2d/kernel\"\n", |
| 1347 | + ") # name of tensor of convolutional kernel of next layer\n", |
1319 | 1348 | "kernel = editor[kernel_name]\n",
|
1320 | | - "saw_dir = nmf.channel_dirs[0][None, None, :, None] # first NMF direction, corresponding to saw obstacle\n",
| 1349 | + "saw_dir = nmf.channel_dirs[0][\n", |
| 1350 | + " None, None, :, None\n", |
| 1351 | + "] # first NMF direction, corresponding to saw obstacle\n", |
1321 | 1352 | "saw_dir /= np.linalg.norm(saw_dir)\n",
|
1322 | 1353 | "# the kernel is left-multiplied by the activations from the previous layer, so we left-multiply the kernel by the projection matrix\n",
|
1323 | | - "kernel = kernel - saw_dir * (saw_dir * kernel).sum(axis=-2, keepdims=True) # equivalently: kernel - saw_dir @ saw_dir.transpose((0, 1, 3, 2)) @ kernel\n",
| 1354 | + "kernel = kernel - saw_dir * (saw_dir * kernel).sum(\n", |
| 1355 | + " axis=-2, keepdims=True\n", |
| 1356 | + ") # equivalently: kernel - saw_dir @ saw_dir.transpose((0, 1, 3, 2)) @ kernel\n", |
1324 | 1357 | "editor[kernel_name] = kernel\n",
|
1325 | 1358 | "# note: this is not quite the same as the edit made for the paper, since we only used 1024 observations for the NMF calculation here"
|
1326 | 1359 | ],
|
|
1415 | 1448 | }
|
1416 | 1449 | },
|
1417 | 1450 | "source": [
|
1418 | | - "traj = trajectories['observations'][0][76:84]\n",

1419 | | - "attr = get_attr(edited_model, value_function_name, layer_names['2b'], traj, integrate_steps=10)\n",
| 1451 | + "traj = trajectories[\"observations\"][0][76:84]\n", |
| 1452 | + "attr = get_attr(\n", |
| 1453 | + " edited_model, value_function_name, layer_names[\"2b\"], traj, integrate_steps=10\n", |
| 1454 | + ")\n", |
1420 | 1455 | "attr_reduced = nmf.transform(np.maximum(attr, 0)) - nmf.transform(np.maximum(-attr, 0))\n",
|
1421 | 1456 | "nmf_norms = nmf.channel_dirs.sum(-1)\n",
|
1422 | 1457 | "attr_reduced *= nmf_norms[None, None, None]\n",
|
|