File tree Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Expand file tree Collapse file tree 1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -309,10 +309,17 @@ def generate_factual_dataset(
309
309
sampler = None ,
310
310
filter = None ,
311
311
device = "cpu" ,
312
+ input_function = None ,
313
+ output_function = None ,
312
314
return_tensors = True ,
313
315
):
314
316
if sampler is None :
315
317
sampler = self .sample_input
318
+
319
+ if input_function is None :
320
+ input_function = self .input_to_tensor
321
+ if output_function is None :
322
+ output_function = self .output_to_tensor
316
323
317
324
examples = []
318
325
while len (examples ) < size :
@@ -321,8 +328,8 @@ def generate_factual_dataset(
321
328
if filter is None or filter (input ):
322
329
output = self .run_forward (input )
323
330
if return_tensors :
324
- example ['input_ids' ] = self . input_to_tensor (input ).to (device )
325
- example ['labels' ] = self . output_to_tensor (output ).to (device )
331
+ example ['input_ids' ] = input_function (input ).to (device )
332
+ example ['labels' ] = output_function (output ).to (device )
326
333
else :
327
334
example ['input_ids' ] = input
328
335
example ['labels' ] = output
@@ -339,8 +346,15 @@ def generate_counterfactual_dataset(
339
346
intervention_sampler = None ,
340
347
filter = None ,
341
348
device = "cpu" ,
349
+ input_function = None ,
350
+ output_function = None ,
342
351
return_tensors = True ,
343
352
):
353
+ if input_function is None :
354
+ input_function = self .input_to_tensor
355
+ if output_function is None :
356
+ output_function = self .output_to_tensor
357
+
344
358
maxlength = len (
345
359
[
346
360
var
You can’t perform that action at this time.
0 commit comments