
Commit 95be3bc

Merge pull request #14 from HumanCompatibleAI/layer_nmf_types
Add type annotations to stuff
2 parents e61bf0e + f6637eb commit 95be3bc

4 files changed, +41 -15 lines changed


ci/code_checks.sh

+2 -1

@@ -6,5 +6,6 @@ SRC_FILES=(src/ tests/ setup.py)
 set -x # echo commands
 set -e # quit immediately on error

+pytype ${SRC_FILES[@]}
 flake8 ${SRC_FILES[@]}
-black --check ${SRC_FILES[@]}
+black --check ${SRC_FILES[@]}

src/reward_preprocessing/common/utils.py

+7

@@ -8,6 +8,7 @@
 import torch as th
 from torch import nn as nn
 from torch.utils import data as torch_data
+import vegans.utils


 def make_transition_to_tensor(num_acts):
@@ -206,3 +207,9 @@ def forward(self, transition_tensor: th.Tensor) -> th.Tensor:

         dones = th.zeros_like(obs[:, 0])
         return self.rew_net(state=obs, action=act, next_state=next_obs, done=dones)
+
+
+def save_loss_plots(losses, save_dir):
+    """Save plots of generator/adversary losses over training."""
+    fig, _ = vegans.utils.plot_losses(losses, show=False)
+    fig.savefig(Path(save_dir) / "loss_fig.png")
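For reference, a minimal sketch of how the new save_loss_plots helper is meant to be called, mirroring the train_gan.py change below. This is not self-contained: the `gan` object is assumed to be a trained vegans GAN and is not defined in this commit.

```python
from reward_preprocessing.common import utils

# Sketch under assumptions: `gan` is a trained vegans GAN whose
# get_training_results() returns (samples, losses), as used in
# scripts/train_gan.py below.
samples, losses = gan.get_training_results()

# Plot the generator/adversary loss curves via vegans.utils.plot_losses
# and write them to <gan.folder>/loss_fig.png.
utils.save_loss_plots(losses, gan.folder)
```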

src/reward_preprocessing/scripts/train_gan.py

+1 -1

@@ -101,7 +101,7 @@ def train_gan(
     # save samples, return losses, save plot of losses
     samples, losses = gan.get_training_results()
     utils.save_loss_plots(losses, gan.folder)
-    utils.visualize_samples(samples, num_acts, gan.folder)
+    utils.visualize_samples(samples, gan.folder)
     return losses


src/reward_preprocessing/vis/reward_vis.py

+31 -13

@@ -1,7 +1,7 @@
 """Port of lucid.scratch.rl_util to PyTorch. APL2.0 licensed."""
 from functools import reduce
 import logging
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Union

 import lucent.optvis.param as param
 import lucent.optvis.render as render
@@ -15,7 +15,13 @@
 import reward_preprocessing.vis.objectives as objectives_rfi


-def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=None):
+def argmax_nd(
+    x: np.ndarray,
+    axes: List[int],
+    *,
+    max_rep: Union[int, float] = np.inf,
+    max_rep_strict: Optional[bool] = None,
+):
     """Return the indices of the maximum value along the given axes.

     Args:
@@ -37,7 +43,7 @@ def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=
     if max_rep <= 0:
         raise ValueError("max_rep must be greater than 0.")
     if max_rep_strict is None and not np.isinf(max_rep):
-        raise ValueError("if max_rep_strict is not set if max_rep must be infinite.")
+        raise ValueError("if max_rep_strict is not set, then max_rep must be infinite.")
     # Make it so the axes we want to find the maximum along are the first ones...
     perm = list(range(len(x.shape)))
     for axis in reversed(axes):
@@ -94,14 +100,14 @@ class LayerNMF:

     def __init__(
         self,
-        model,
-        layer_name,
-        model_inputs_preprocess,
-        model_inputs_full=None,
+        model: th.nn.Module,
+        layer_name: str,
+        model_inputs_preprocess: th.Tensor,
+        model_inputs_full: Optional[th.Tensor] = None,
         features: Optional[int] = 10,
         *,
         attr_layer_name: Optional[str] = None,
-        attr_opts={"integrate_steps": 10},
+        attr_opts: Dict[str, int] = {"integrate_steps": 10},
         activation_fn: Optional[str] = None,
     ):
         """Use Non-negative matrix factorization dimensionality reduction to then do
@@ -231,9 +237,9 @@ def vis_traditional(
         self,
         feature_list=None,
         *,
-        transforms=[transform.jitter(2)],
-        l2_coeff=0.0,
-        l2_layer_name=None,
+        transforms: List[Callable[[th.Tensor], th.Tensor]] = [transform.jitter(2)],
+        l2_coeff: float = 0.0,
+        l2_layer_name: Optional[str] = None,
     ):
         if feature_list is None:
             # Feature dim is at index 1
@@ -329,7 +335,14 @@ def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
         slice_w = slice(int(round(left_w)), int(round(right_w)))
         return self.padded_obses[obs_index, :, slice_h, slice_w]

-    def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
+    def vis_dataset(
+        self,
+        feature: Union[int, List[int]],
+        *,
+        subdiv_mult=1,
+        expand_mult=1,
+        top_frac: float = 0.1,
+    ):
         """Visualize a dataset of patches that maximize a given feature.

         Args:
@@ -394,7 +407,12 @@ def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
         )

     def vis_dataset_thumbnail(
-        self, feature, *, num_mult=1, expand_mult=1, max_rep=None
+        self,
+        feature: Union[int, List[int]],
+        *,
+        num_mult: int = 1,
+        expand_mult: int = 1,
+        max_rep: Optional[Union[int, float]] = None,
     ):
         """Visualize a dataset of patches that maximize a given feature.

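As a quick illustration of the newly annotated argmax_nd signature and the corrected error message, here is a sketch; the input array, its shape, and its contents are made-up examples, not from the repository.

```python
import numpy as np

from reward_preprocessing.vis.reward_vis import argmax_nd

# Hypothetical activation tensor of shape (batch, height, width), random
# values purely for illustration.
acts = np.random.rand(16, 8, 8)

# With the annotated defaults (max_rep=np.inf, max_rep_strict=None) no
# repetition limit applies: indices of the maximum over the spatial axes.
idx = argmax_nd(acts, axes=[1, 2])

# A finite max_rep without max_rep_strict is rejected, matching the
# corrected error message in this diff.
try:
    argmax_nd(acts, axes=[1, 2], max_rep=3)
except ValueError as err:
    print(err)  # "if max_rep_strict is not set, then max_rep must be infinite."
```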
