Skip to content

Commit 04f74e2

Browse files
committed
Fix traditional visualization
1 parent 5a79629 commit 04f74e2

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

src/reward_preprocessing/vis/reward_vis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,14 @@ def vis_traditional(
289289
for feature in feature_list
290290
]
291291
)
292+
292293
if l2_coeff != 0.0:
293294
if l2_layer_name is None:
294295
raise ValueError(
295296
"l2_layer_name must be specified if l2_coeff is non-zero"
296297
)
297298
obj -= l2_objective(l2_layer_name, l2_coeff)
299+
298300
input_shape = tuple(self.model_inputs_preprocess.shape[1:])
299301

300302
if param_f is None:
@@ -305,6 +307,7 @@ def param_f():
305307
h=input_shape[1],
306308
w=input_shape[2],
307309
batch=len(feature_list),
310+
sd=3,
308311
)
309312

310313
logging.info(f"Performing vis_traditional with transforms: {transforms}")

0 commit comments

Comments
 (0)