File tree Expand file tree Collapse file tree 2 files changed +17
-3
lines changed Expand file tree Collapse file tree 2 files changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -309,10 +309,17 @@ def generate_factual_dataset(
309
309
sampler = None ,
310
310
filter = None ,
311
311
device = "cpu" ,
312
+ input_function = None ,
313
+ output_function = None ,
312
314
return_tensors = True ,
313
315
):
314
316
if sampler is None :
315
317
sampler = self .sample_input
318
+
319
+ if input_function is None :
320
+ input_function = self .input_to_tensor
321
+ if output_function is None :
322
+ output_function = self .output_to_tensor
316
323
317
324
examples = []
318
325
while len (examples ) < size :
@@ -321,8 +328,8 @@ def generate_factual_dataset(
321
328
if filter is None or filter (input ):
322
329
output = self .run_forward (input )
323
330
if return_tensors :
324
- example ['input_ids' ] = self . input_to_tensor (input ).to (device )
325
- example ['labels' ] = self . output_to_tensor (output ).to (device )
331
+ example ['input_ids' ] = input_function (input ).to (device )
332
+ example ['labels' ] = output_function (output ).to (device )
326
333
else :
327
334
example ['input_ids' ] = input
328
335
example ['labels' ] = output
@@ -339,8 +346,15 @@ def generate_counterfactual_dataset(
339
346
intervention_sampler = None ,
340
347
filter = None ,
341
348
device = "cpu" ,
349
+ input_function = None ,
350
+ output_function = None ,
342
351
return_tensors = True ,
343
352
):
353
+ if input_function is None :
354
+ input_function = self .input_to_tensor
355
+ if output_function is None :
356
+ output_function = self .output_to_tensor
357
+
344
358
maxlength = len (
345
359
[
346
360
var
Original file line number Diff line number Diff line change 10
10
11
11
setup (
12
12
name = "pyvene" ,
13
- version = "0.0.8 " ,
13
+ version = "0.0.9dev " ,
14
14
description = "Use Activation Intervention to Interpret Causal Mechanism of Model" ,
15
15
long_description = long_description ,
16
16
long_description_content_type = 'text/markdown' ,
You can’t perform that action at this time.
0 commit comments