1
1
"""Port of lucid.scratch.rl_util to PyTorch. APL2.0 licensed."""
2
2
from functools import reduce
3
3
import logging
4
- from typing import List , Optional
4
+ from typing import Callable , Dict , List , Optional , Union
5
5
6
6
import lucent .optvis .param as param
7
7
import lucent .optvis .render as render
15
15
import reward_preprocessing .vis .objectives as objectives_rfi
16
16
17
17
18
- def argmax_nd (x : np .ndarray , axes : List [int ], * , max_rep = np .inf , max_rep_strict = None ):
18
+ def argmax_nd (
19
+ x : np .ndarray ,
20
+ axes : List [int ],
21
+ * ,
22
+ max_rep : Union [int , float ] = np .inf ,
23
+ max_rep_strict : Optional [bool ] = None ,
24
+ ):
19
25
"""Return the indices of the maximum value along the given axes.
20
26
21
27
Args:
@@ -37,7 +43,7 @@ def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=
37
43
if max_rep <= 0 :
38
44
raise ValueError ("max_rep must be greater than 0." )
39
45
if max_rep_strict is None and not np .isinf (max_rep ):
40
- raise ValueError ("if max_rep_strict is not set if max_rep must be infinite." )
46
+ raise ValueError ("if max_rep_strict is not set, then max_rep must be infinite." )
41
47
# Make it so the axes we want to find the maximum along are the first ones...
42
48
perm = list (range (len (x .shape )))
43
49
for axis in reversed (axes ):
@@ -94,14 +100,14 @@ class LayerNMF:
94
100
95
101
def __init__ (
96
102
self ,
97
- model ,
98
- layer_name ,
99
- model_inputs_preprocess ,
100
- model_inputs_full = None ,
103
+ model : th . nn . Module ,
104
+ layer_name : str ,
105
+ model_inputs_preprocess : th . Tensor ,
106
+ model_inputs_full : Optional [ th . Tensor ] = None ,
101
107
features : Optional [int ] = 10 ,
102
108
* ,
103
109
attr_layer_name : Optional [str ] = None ,
104
- attr_opts = {"integrate_steps" : 10 },
110
+ attr_opts : Dict [ str , int ] = {"integrate_steps" : 10 },
105
111
activation_fn : Optional [str ] = None ,
106
112
):
107
113
"""Use Non-negative matrix factorization dimensionality reduction to then do
@@ -231,9 +237,9 @@ def vis_traditional(
231
237
self ,
232
238
feature_list = None ,
233
239
* ,
234
- transforms = [transform .jitter (2 )],
235
- l2_coeff = 0.0 ,
236
- l2_layer_name = None ,
240
+ transforms : List [ Callable [[ th . Tensor ], th . Tensor ]] = [transform .jitter (2 )],
241
+ l2_coeff : float = 0.0 ,
242
+ l2_layer_name : Optional [ str ] = None ,
237
243
):
238
244
if feature_list is None :
239
245
# Feature dim is at index 1
@@ -329,7 +335,14 @@ def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
329
335
slice_w = slice (int (round (left_w )), int (round (right_w )))
330
336
return self .padded_obses [obs_index , :, slice_h , slice_w ]
331
337
332
- def vis_dataset (self , feature , * , subdiv_mult = 1 , expand_mult = 1 , top_frac = 0.1 ):
338
+ def vis_dataset (
339
+ self ,
340
+ feature : Union [int , List [int ]],
341
+ * ,
342
+ subdiv_mult = 1 ,
343
+ expand_mult = 1 ,
344
+ top_frac : float = 0.1 ,
345
+ ):
333
346
"""Visualize a dataset of patches that maximize a given feature.
334
347
335
348
Args:
@@ -394,7 +407,12 @@ def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
394
407
)
395
408
396
409
def vis_dataset_thumbnail (
397
- self , feature , * , num_mult = 1 , expand_mult = 1 , max_rep = None
410
+ self ,
411
+ feature : Union [int , List [int ]],
412
+ * ,
413
+ num_mult : int = 1 ,
414
+ expand_mult : int = 1 ,
415
+ max_rep : Optional [Union [int , float ]] = None ,
398
416
):
399
417
"""Visualize a dataset of patches that maximize a given feature.
400
418
0 commit comments