Skip to content

Commit ae85bc8

Browse files
committed
Reverted causal model to accept input_function & output_function parameters
1 parent 37782e6 commit ae85bc8

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

pyvene/data_generators/causal_model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,17 @@ def generate_factual_dataset(
309309
sampler=None,
310310
filter=None,
311311
device="cpu",
312+
input_function=None,
313+
output_function=None,
312314
return_tensors=True,
313315
):
314316
if sampler is None:
315317
sampler = self.sample_input
318+
319+
if input_function is None:
320+
input_function = self.input_to_tensor
321+
if output_function is None:
322+
output_function = self.output_to_tensor
316323

317324
examples = []
318325
while len(examples) < size:
@@ -321,8 +328,8 @@ def generate_factual_dataset(
321328
if filter is None or filter(input):
322329
output = self.run_forward(input)
323330
if return_tensors:
324-
example['input_ids'] = self.input_to_tensor(input).to(device)
325-
example['labels'] = self.output_to_tensor(output).to(device)
331+
example['input_ids'] = input_function(input).to(device)
332+
example['labels'] = output_function(output).to(device)
326333
else:
327334
example['input_ids'] = input
328335
example['labels'] = output
@@ -339,8 +346,15 @@ def generate_counterfactual_dataset(
339346
intervention_sampler=None,
340347
filter=None,
341348
device="cpu",
349+
input_function=None,
350+
output_function=None,
342351
return_tensors=True,
343352
):
353+
if input_function is None:
354+
input_function = self.input_to_tensor
355+
if output_function is None:
356+
output_function = self.output_to_tensor
357+
344358
maxlength = len(
345359
[
346360
var

0 commit comments

Comments
 (0)