Skip to content

Commit 607551b

Browse files
committed
merge in main
1 parent 3566f6b commit 607551b

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

ci/code_checks.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ SRC_FILES=(src/ tests/ setup.py)
66
set -x # echo commands
77
set -e # quit immediately on error
88

9+
pytype ${SRC_FILES[@]}
910
flake8 ${SRC_FILES[@]}
10-
black --check ${SRC_FILES[@]}
11+
black --check ${SRC_FILES[@]}

src/reward_preprocessing/vis/reward_vis.py

+31-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Port of lucid.scratch.rl_util to PyTorch. APL2.0 licensed."""
22
from functools import reduce
33
import logging
4-
from typing import List, Optional
4+
from typing import Callable, Dict, List, Optional, Union
55

66
from lucent.optvis.objectives import handle_batch, wrap_objective
77
import lucent.optvis.param as param
@@ -16,7 +16,13 @@
1616
import reward_preprocessing.vis.objectives as objectives_rfi
1717

1818

19-
def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=None):
19+
def argmax_nd(
20+
x: np.ndarray,
21+
axes: List[int],
22+
*,
23+
max_rep: Union[int, float] = np.inf,
24+
max_rep_strict: Optional[bool] = None,
25+
):
2026
"""Return the indices of the maximum value along the given axes.
2127
2228
Args:
@@ -38,7 +44,7 @@ def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=
3844
if max_rep <= 0:
3945
raise ValueError("max_rep must be greater than 0.")
4046
if max_rep_strict is None and not np.isinf(max_rep):
41-
raise ValueError("if max_rep_strict is not set if max_rep must be infinite.")
47+
raise ValueError("if max_rep_strict is not set, then max_rep must be infinite.")
4248
# Make it so the axes we want to find the maximum along are the first ones...
4349
perm = list(range(len(x.shape)))
4450
for axis in reversed(axes):
@@ -106,14 +112,14 @@ class LayerNMF:
106112

107113
def __init__(
108114
self,
109-
model,
110-
layer_name,
111-
model_inputs_preprocess,
112-
model_inputs_full=None,
115+
model: th.nn.Module,
116+
layer_name: str,
117+
model_inputs_preprocess: th.Tensor,
118+
model_inputs_full: Optional[th.Tensor] = None,
113119
features: Optional[int] = 10,
114120
*,
115121
attr_layer_name: Optional[str] = None,
116-
attr_opts={"integrate_steps": 10},
122+
attr_opts: Dict[str, int] = {"integrate_steps": 10},
117123
activation_fn: Optional[str] = None,
118124
):
119125
"""Use Non-negative matrix factorization dimensionality reduction to then do
@@ -244,9 +250,9 @@ def vis_traditional(
244250
self,
245251
feature_list=None,
246252
*,
247-
transforms=[transform.jitter(2)],
248-
l2_coeff=0.0,
249-
l2_layer_name=None,
253+
transforms: List[Callable[[th.Tensor], th.Tensor]] = [transform.jitter(2)],
254+
l2_coeff: float = 0.0,
255+
l2_layer_name: Optional[str] = None,
250256
) -> np.ndarray:
251257
if feature_list is None:
252258
# Feature dim is at index 1
@@ -341,7 +347,14 @@ def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
341347
slice_w = slice(int(round(left_w)), int(round(right_w)))
342348
return self.padded_obses[obs_index, :, slice_h, slice_w]
343349

344-
def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
350+
def vis_dataset(
351+
self,
352+
feature: Union[int, List[int]],
353+
*,
354+
subdiv_mult=1,
355+
expand_mult=1,
356+
top_frac: float = 0.1,
357+
):
345358
"""Visualize a dataset of patches that maximize a given feature.
346359
347360
Args:
@@ -406,7 +419,12 @@ def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
406419
)
407420

408421
def vis_dataset_thumbnail(
409-
self, feature, *, num_mult=1, expand_mult=1, max_rep=None
422+
self,
423+
feature: Union[int, List[int]],
424+
*,
425+
num_mult: int = 1,
426+
expand_mult: int = 1,
427+
max_rep: Optional[Union[int, float]] = None,
410428
):
411429
"""Visualize a dataset of patches that maximize a given feature.
412430

0 commit comments

Comments
 (0)