File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -309,10 +309,17 @@ def generate_factual_dataset(
309309 sampler = None ,
310310 filter = None ,
311311 device = "cpu" ,
312+ input_function = None ,
313+ output_function = None ,
312314 return_tensors = True ,
313315 ):
314316 if sampler is None :
315317 sampler = self .sample_input
318+
319+ if input_function is None :
320+ input_function = self .input_to_tensor
321+ if output_function is None :
322+ output_function = self .output_to_tensor
316323
317324 examples = []
318325 while len (examples ) < size :
@@ -321,8 +328,8 @@ def generate_factual_dataset(
321328 if filter is None or filter (input ):
322329 output = self .run_forward (input )
323330 if return_tensors :
324- example ['input_ids' ] = self . input_to_tensor (input ).to (device )
325- example ['labels' ] = self . output_to_tensor (output ).to (device )
331+ example ['input_ids' ] = input_function (input ).to (device )
332+ example ['labels' ] = output_function (output ).to (device )
326333 else :
327334 example ['input_ids' ] = input
328335 example ['labels' ] = output
@@ -339,8 +346,15 @@ def generate_counterfactual_dataset(
339346 intervention_sampler = None ,
340347 filter = None ,
341348 device = "cpu" ,
349+ input_function = None ,
350+ output_function = None ,
342351 return_tensors = True ,
343352 ):
353+ if input_function is None :
354+ input_function = self .input_to_tensor
355+ if output_function is None :
356+ output_function = self .output_to_tensor
357+
344358 maxlength = len (
345359 [
346360 var
Original file line number Diff line number Diff line change 1010
1111setup (
1212 name = "pyvene" ,
13- version = "0.0.8 " ,
13+ version = "0.0.9dev " ,
1414 description = "Use Activation Intervention to Interpret Causal Mechanism of Model" ,
1515 long_description = long_description ,
1616 long_description_content_type = 'text/markdown' ,
You can’t perform that action at this time.
0 commit comments