Skip to content

Commit faf581b

Browse files
authored
Merge pull request #137 from amirzur2023/main
[Minor] Revert CausalModel to accept input/output functions when generating factual/counterfactual datasets
2 parents 37782e6 + 4b90cac commit faf581b

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

pyvene/data_generators/causal_model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,17 @@ def generate_factual_dataset(
309309
sampler=None,
310310
filter=None,
311311
device="cpu",
312+
input_function=None,
313+
output_function=None,
312314
return_tensors=True,
313315
):
314316
if sampler is None:
315317
sampler = self.sample_input
318+
319+
if input_function is None:
320+
input_function = self.input_to_tensor
321+
if output_function is None:
322+
output_function = self.output_to_tensor
316323

317324
examples = []
318325
while len(examples) < size:
@@ -321,8 +328,8 @@ def generate_factual_dataset(
321328
if filter is None or filter(input):
322329
output = self.run_forward(input)
323330
if return_tensors:
324-
example['input_ids'] = self.input_to_tensor(input).to(device)
325-
example['labels'] = self.output_to_tensor(output).to(device)
331+
example['input_ids'] = input_function(input).to(device)
332+
example['labels'] = output_function(output).to(device)
326333
else:
327334
example['input_ids'] = input
328335
example['labels'] = output
@@ -339,8 +346,15 @@ def generate_counterfactual_dataset(
339346
intervention_sampler=None,
340347
filter=None,
341348
device="cpu",
349+
input_function=None,
350+
output_function=None,
342351
return_tensors=True,
343352
):
353+
if input_function is None:
354+
input_function = self.input_to_tensor
355+
if output_function is None:
356+
output_function = self.output_to_tensor
357+
344358
maxlength = len(
345359
[
346360
var

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="pyvene",
13-
version="0.0.8",
13+
version="0.0.9dev",
1414
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)