1
1
"""Port of lucid.scratch.rl_util to PyTorch. APL2.0 licensed."""
2
2
from functools import reduce
3
3
import logging
4
- from typing import List , Optional
4
+ from typing import Callable , Dict , List , Optional , Union
5
5
6
6
from lucent .optvis .objectives import handle_batch , wrap_objective
7
7
import lucent .optvis .param as param
16
16
import reward_preprocessing .vis .objectives as objectives_rfi
17
17
18
18
19
- def argmax_nd (x : np .ndarray , axes : List [int ], * , max_rep = np .inf , max_rep_strict = None ):
19
+ def argmax_nd (
20
+ x : np .ndarray ,
21
+ axes : List [int ],
22
+ * ,
23
+ max_rep : Union [int , float ] = np .inf ,
24
+ max_rep_strict : Optional [bool ] = None ,
25
+ ):
20
26
"""Return the indices of the maximum value along the given axes.
21
27
22
28
Args:
@@ -38,7 +44,7 @@ def argmax_nd(x: np.ndarray, axes: List[int], *, max_rep=np.inf, max_rep_strict=
38
44
if max_rep <= 0 :
39
45
raise ValueError ("max_rep must be greater than 0." )
40
46
if max_rep_strict is None and not np .isinf (max_rep ):
41
- raise ValueError ("if max_rep_strict is not set if max_rep must be infinite." )
47
+ raise ValueError ("if max_rep_strict is not set, then max_rep must be infinite." )
42
48
# Make it so the axes we want to find the maximum along are the first ones...
43
49
perm = list (range (len (x .shape )))
44
50
for axis in reversed (axes ):
@@ -106,14 +112,14 @@ class LayerNMF:
106
112
107
113
def __init__ (
108
114
self ,
109
- model ,
110
- layer_name ,
111
- model_inputs_preprocess ,
112
- model_inputs_full = None ,
115
+ model : th . nn . Module ,
116
+ layer_name : str ,
117
+ model_inputs_preprocess : th . Tensor ,
118
+ model_inputs_full : Optional [ th . Tensor ] = None ,
113
119
features : Optional [int ] = 10 ,
114
120
* ,
115
121
attr_layer_name : Optional [str ] = None ,
116
- attr_opts = {"integrate_steps" : 10 },
122
+ attr_opts : Dict [ str , int ] = {"integrate_steps" : 10 },
117
123
activation_fn : Optional [str ] = None ,
118
124
):
119
125
"""Use Non-negative matrix factorization dimensionality reduction to then do
@@ -244,9 +250,9 @@ def vis_traditional(
244
250
self ,
245
251
feature_list = None ,
246
252
* ,
247
- transforms = [transform .jitter (2 )],
248
- l2_coeff = 0.0 ,
249
- l2_layer_name = None ,
253
+ transforms : List [ Callable [[ th . Tensor ], th . Tensor ]] = [transform .jitter (2 )],
254
+ l2_coeff : float = 0.0 ,
255
+ l2_layer_name : Optional [ str ] = None ,
250
256
) -> np .ndarray :
251
257
if feature_list is None :
252
258
# Feature dim is at index 1
@@ -341,7 +347,14 @@ def get_patch(self, obs_index, pos_h, pos_w, *, expand_mult=1):
341
347
slice_w = slice (int (round (left_w )), int (round (right_w )))
342
348
return self .padded_obses [obs_index , :, slice_h , slice_w ]
343
349
344
- def vis_dataset (self , feature , * , subdiv_mult = 1 , expand_mult = 1 , top_frac = 0.1 ):
350
+ def vis_dataset (
351
+ self ,
352
+ feature : Union [int , List [int ]],
353
+ * ,
354
+ subdiv_mult = 1 ,
355
+ expand_mult = 1 ,
356
+ top_frac : float = 0.1 ,
357
+ ):
345
358
"""Visualize a dataset of patches that maximize a given feature.
346
359
347
360
Args:
@@ -406,7 +419,12 @@ def vis_dataset(self, feature, *, subdiv_mult=1, expand_mult=1, top_frac=0.1):
406
419
)
407
420
408
421
def vis_dataset_thumbnail (
409
- self , feature , * , num_mult = 1 , expand_mult = 1 , max_rep = None
422
+ self ,
423
+ feature : Union [int , List [int ]],
424
+ * ,
425
+ num_mult : int = 1 ,
426
+ expand_mult : int = 1 ,
427
+ max_rep : Optional [Union [int , float ]] = None ,
410
428
):
411
429
"""Visualize a dataset of patches that maximize a given feature.
412
430
0 commit comments