Skip to content

Commit a2b8e3a

Browse files
authored
Merge pull request #131 from amirzur2023/main
Test Cases for Causal Model, MQNLI Intro Notebook
2 parents 89b8e4f + 3fb81b1 commit a2b8e3a

File tree

9 files changed

+2317
-209
lines changed

9 files changed

+2317
-209
lines changed

pyvene/data_generators/causal_model.py

+63-42
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def generate_timesteps(self):
107107
step += 1
108108
for var in self.variables:
109109
assert var in timesteps
110-
return timesteps, step - 1
110+
# return all timesteps and timestep of root
111+
return timesteps, step - 2
111112

112113
def marginalize(self, target):
113114
pass
@@ -148,9 +149,12 @@ def find_live_paths(self, intervention):
148149
del paths[1]
149150
return paths
150151

151-
def print_setting(self, total_setting):
152+
def print_setting(self, total_setting, display=None):
153+
labeler = lambda var: var + ": " + str(total_setting[var]) \
154+
if display is None or display[var] \
155+
else var
152156
relabeler = {
153-
var: var + ": " + str(total_setting[var]) for var in self.variables
157+
var: labeler(var) for var in self.variables
154158
}
155159
G = nx.DiGraph()
156160
G.add_edges_from(
@@ -227,21 +231,27 @@ def sample_input(self, mandatory=None):
227231
total = self.run_forward(intervention=input)
228232
return input
229233

def sample_input_tree_balanced(self, output_var=None, output_var_value=None):
    """Sample an input setting that drives ``output_var`` to a chosen value.

    The target value is propagated down the causal graph: for each
    non-input variable, one entry of ``self.equiv_classes[var][value]``
    (a mapping from parent name to parent value) is picked at random and
    the walk continues into each parent until input variables are reached.

    Args:
        output_var: Variable whose value is targeted. May be omitted only
            when the model has exactly one output, which is then used.
        output_var_value: Value ``output_var`` should take. When ``None``,
            one is drawn uniformly from ``self.values[output_var]``.

    Returns:
        A dict assigning a value to every variable in ``self.inputs``.
        Inputs not reached by the walk are filled in at random.

    Raises:
        AssertionError: If ``output_var`` is omitted and the model does not
            have exactly one output.
    """
    assert output_var is not None or len(self.outputs) == 1
    if output_var is None:
        output_var = self.outputs[0]
    if output_var_value is None:
        output_var_value = random.choice(self.values[output_var])

    # One fresh dict per call; the nested helper writes into it directly
    # instead of relying on a mutable default argument.
    input_setting = {}

    def assign_inputs(var, value):
        # An input variable has no equivalence classes to sample from, so
        # it is pinned directly. This also covers a target that is itself
        # an input (e.g. an intervention on an input variable).
        if var in self.inputs:
            input_setting[var] = value
            return
        parent_values = random.choice(self.equiv_classes[var][value])
        for parent in parent_values:
            assign_inputs(parent, parent_values[parent])

    assign_inputs(output_var, output_var_value)

    # Inputs that are not ancestors of the target are unconstrained.
    for input_var in self.inputs:
        if input_var not in input_setting:
            input_setting[input_var] = random.choice(self.values[input_var])
    return input_setting
245255

246256
def get_path_maxlen_filter(self, lengths):
247257
def check_path(total_setting):
@@ -299,24 +309,26 @@ def generate_factual_dataset(
299309
sampler=None,
300310
filter=None,
301311
device="cpu",
302-
inputFunction=None,
303-
outputFunction=None
312+
return_tensors=True,
304313
):
305-
if inputFunction is None:
306-
inputFunction = self.input_to_tensor
307-
if outputFunction is None:
308-
outputFunction = self.output_to_tensor
309314
if sampler is None:
310315
sampler = self.sample_input
311-
X, y = [], []
312-
count = 0
313-
while count < size:
316+
317+
examples = []
318+
while len(examples) < size:
319+
example = dict()
314320
input = sampler()
315321
if filter is None or filter(input):
316-
X.append(inputFunction(input))
317-
y.append(outputFunction(self.run_forward(input)))
318-
count += 1
319-
return torch.stack(X).to(device), torch.stack(y).to(device)
322+
output = self.run_forward(input)
323+
if return_tensors:
324+
example['input_ids'] = self.input_to_tensor(input).to(device)
325+
example['labels'] = self.output_to_tensor(output).to(device)
326+
else:
327+
example['input_ids'] = input
328+
example['labels'] = output
329+
examples.append(example)
330+
331+
return examples
320332

321333
def generate_counterfactual_dataset(
322334
self,
@@ -327,8 +339,7 @@ def generate_counterfactual_dataset(
327339
intervention_sampler=None,
328340
filter=None,
329341
device="cpu",
330-
inputFunction=None,
331-
outputFunction=None
342+
return_tensors=True,
332343
):
333344
maxlength = len(
334345
[
@@ -337,17 +348,12 @@ def generate_counterfactual_dataset(
337348
if var not in self.inputs and var not in self.outputs
338349
]
339350
)
340-
if inputFunction is None:
341-
inputFunction = self.input_to_tensor
342-
if outputFunction is None:
343-
outputFunction = self.output_to_tensor
344351
if sampler is None:
345352
sampler = self.sample_input
346353
if intervention_sampler is None:
347354
intervention_sampler = self.sample_intervention
348355
examples = []
349-
count = 0
350-
while count < size:
356+
while len(examples) < size:
351357
intervention = intervention_sampler()
352358
if filter is None or filter(intervention):
353359
for _ in range(batch_size):
@@ -358,24 +364,39 @@ def generate_counterfactual_dataset(
358364
for var in self.variables:
359365
if var not in intervention:
360366
continue
361-
source = sampler()
362-
sources.append(inputFunction(source))
367+
# sample input to match sampled intervention value
368+
source = sampler(output_var=var, output_var_value=intervention[var])
369+
if return_tensors:
370+
sources.append(self.input_to_tensor(source))
371+
else:
372+
sources.append(source)
363373
source_dic[var] = source
364374
for _ in range(maxlength - len(sources)):
365-
sources.append(torch.zeros(self.input_to_tensor(sampler()).shape))
366-
example["labels"] = outputFunction(
367-
self.run_interchange(base, source_dic)
368-
).to(device)
369-
example["base_labels"] = outputFunction(
370-
self.run_forward(base)
371-
).to(device)
372-
example["input_ids"] = inputFunction(base).to(device)
373-
example["source_input_ids"] = torch.stack(sources).to(device)
374-
example["intervention_id"] = torch.tensor(
375-
[intervention_id(intervention)]
376-
).to(device)
375+
if return_tensors:
376+
sources.append(torch.zeros(self.input_to_tensor(base).shape))
377+
else:
378+
sources.append({})
379+
380+
if return_tensors:
381+
example["labels"] = self.output_to_tensor(
382+
self.run_interchange(base, source_dic)
383+
).to(device)
384+
example["base_labels"] = self.output_to_tensor(
385+
self.run_forward(base)
386+
).to(device)
387+
example["input_ids"] = self.input_to_tensor(base).to(device)
388+
example["source_input_ids"] = torch.stack(sources).to(device)
389+
example["intervention_id"] = torch.tensor(
390+
[intervention_id(intervention)]
391+
).to(device)
392+
else:
393+
example['labels'] = self.run_interchange(base, source_dic)
394+
example['base_labels'] = self.run_forward(base)
395+
example['input_ids'] = base
396+
example['source_input_ids'] = sources
397+
example['intervention_id'] = [intervention_id(intervention)]
398+
377399
examples.append(example)
378-
count += 1
379400
return examples
380401

381402

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import unittest
2+
import random
3+
import torch
4+
from pyvene import CausalModel
5+
random.seed(42)
6+
7+
class CasualModelTestCase(unittest.TestCase):
    """Tests for ``CausalModel`` on a three-variable model where C = A and B.

    NOTE(review): the class name is spelled "Casual"; it is kept as is
    because ``suite()`` refers to the class by this name.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared model once: inputs A and B, output C = A and B."""
        print("=== Test Suite: CausalModelTestCase ===")
        cls.variables = ['A', 'B', 'C']
        cls.values = {
            'A': [False, True],
            'B': [False, True],
            'C': [False, True],
        }

        cls.parents = {
            'A': [],
            'B': [],
            'C': ['A', 'B'],
        }

        # A and B take no parents, so their functions return a default value.
        cls.functions = {
            "A": lambda: True,
            "B": lambda: True,
            "C": lambda a, b: a and b,
        }

        cls.causal_model = CausalModel(
            cls.variables,
            cls.values,
            cls.parents,
            cls.functions,
        )

    def test_initialization(self):
        """Inputs, outputs, timesteps and equivalence classes are derived."""
        expected_inputs = {'A', 'B'}
        expected_outputs = {'C'}
        expected_timesteps = {'A': 0, 'B': 0, 'C': 1}
        # For each value of C, every parent assignment that produces it.
        expected_equiv_classes = {
            'C': {
                False: [
                    {'A': False, 'B': False},
                    {'A': False, 'B': True},
                    {'A': True, 'B': False},
                ],
                True: [
                    {'A': True, 'B': True},
                ],
            }
        }

        self.assertEqual(set(self.causal_model.inputs), expected_inputs)
        self.assertEqual(set(self.causal_model.outputs), expected_outputs)
        self.assertEqual(self.causal_model.timesteps, expected_timesteps)
        self.assertEqual(self.causal_model.equiv_classes, expected_equiv_classes)

    def test_run_forward(self):
        """``run_forward`` fills in every variable from the given setting."""
        # test run forward with default values (A and B set to True)
        self.assertEqual(
            self.causal_model.run_forward(),
            {'A': True, 'B': True, 'C': True},
        )

        # test run forward on all possible input values
        for a in [False, True]:
            for b in [False, True]:
                expected = {'A': a, 'B': b, 'C': a and b}
                self.assertEqual(
                    self.causal_model.run_forward({'A': a, 'B': b}),
                    expected,
                )

        # test run forward on fully specified setting: the supplied value of
        # C is kept even though it disagrees with A and B
        full_setting = {'A': False, 'B': False, 'C': True}
        self.assertEqual(self.causal_model.run_forward(full_setting), full_setting)

    def test_run_interchange(self):
        """Interchange interventions copy a variable's value from a source run."""
        # interchange intervention on input
        base = {'A': True, 'B': False}
        source = {'A': False, 'B': True}
        self.assertEqual(self.causal_model.run_forward(base)['C'], False)
        self.assertEqual(self.causal_model.run_forward(source)['C'], False)
        # B taken from source (True) with A from base (True) gives C = True
        self.assertEqual(
            self.causal_model.run_interchange(base, {'B': source})['C'],
            True,
        )

        # interchange intervention on output
        base = {'A': False, 'B': False}
        source = {'A': True, 'B': True}
        self.assertEqual(self.causal_model.run_forward(base)['C'], False)
        self.assertEqual(
            self.causal_model.run_interchange(base, {'B': source})['C'],
            False,
        )
        self.assertEqual(
            self.causal_model.run_interchange(base, {'C': source})['C'],
            True,
        )

    def test_sample_input_tree_balanced(self):
        """Tree-balanced sampling balances the output and varies the inputs."""
        # NOTE: not quite sure how to test a function with random behavior
        # right now, fixing seed and assuming approximate behavior
        # (taking balanced to be less than 30-70 split)

        K = 100
        # test sampling by output value
        outputs = [
            self.causal_model.run_forward(
                self.causal_model.sample_input_tree_balanced()
            )['C']
            for _ in range(K)
        ]
        self.assertGreaterEqual(sum(outputs), 30)
        self.assertLessEqual(sum(outputs), 70)

        # test sampling by input value
        inputs = [
            self.causal_model.sample_input_tree_balanced()['A']
            for _ in range(K)
        ]
        # Assert on the sampled inputs themselves (previously this section
        # re-checked `outputs`). Balancing C does not balance A: A is True in
        # the only C=True class and in one of the three C=False classes, so a
        # 30-70 band is not the right expectation here. Only require that
        # both values of A are actually produced.
        self.assertIn(True, inputs)
        self.assertIn(False, inputs)

    def test_generate_factual_dataset(self):
        """Factual datasets come back as raw settings or as tensors."""
        def sampler():
            return {'A': False, 'B': False}

        size = 4
        raw_dataset = self.causal_model.generate_factual_dataset(
            size=size,
            sampler=sampler,
            return_tensors=False,
        )
        self.assertEqual(len(raw_dataset), size)

        first = raw_dataset[0]
        self.assertEqual(first['input_ids'], {'A': False, 'B': False})
        self.assertEqual(first['labels']['C'], False)

        tensor_dataset = self.causal_model.generate_factual_dataset(
            size=size,
            sampler=sampler,
            return_tensors=True,
        )
        self.assertEqual(len(tensor_dataset), size)
        X = torch.stack([example['input_ids'] for example in tensor_dataset])
        y = torch.stack([example['labels'] for example in tensor_dataset])
        self.assertEqual(X.shape, (size, 2))
        self.assertEqual(y.shape, (size, 1))
        self.assertTrue(torch.equal(X[0], torch.tensor([0., 0.])))
        self.assertTrue(torch.equal(y[0], torch.tensor([0.])))

    def test_generate_counterfactual_dataset(self):
        """Counterfactual examples carry base, source and intervened labels."""
        def sampler(*args, **kwargs):
            # Called with `output_var` when a source input is requested and
            # without it when the base input is requested.
            if kwargs.get('output_var', None):
                return {'A': True, 'B': True}

            return {'A': True, 'B': False}

        def intervention_sampler(*args, **kwargs):
            return {'B': True}

        def intervention_id(*args, **kwargs):
            return 0

        size = 4
        counterfactual_dataset = self.causal_model.generate_counterfactual_dataset(
            size=size,
            batch_size=1,
            intervention_id=intervention_id,
            sampler=sampler,
            intervention_sampler=intervention_sampler,
            return_tensors=False,
        )
        self.assertEqual(len(counterfactual_dataset), size)
        example = counterfactual_dataset[0]
        self.assertEqual(example['input_ids'], {'A': True, 'B': False})
        self.assertEqual(example['source_input_ids'][0]['B'], True)
        self.assertEqual(example['intervention_id'], [0])
        self.assertEqual(example['base_labels']['C'], False)  # T and F
        self.assertEqual(example['labels']['C'], True)  # T and T
193+
194+
def suite():
    """Collect the CausalModel test cases into one suite, in a fixed order."""
    test_names = (
        "test_initialization",
        "test_run_forward",
        "test_run_interchange",
        "test_sample_input_tree_balanced",
        "test_generate_factual_dataset",
        "test_generate_counterfactual_dataset",
    )
    collected = unittest.TestSuite()
    for test_name in test_names:
        collected.addTest(CasualModelTestCase(test_name))
    return collected
204+
205+
if __name__ == "__main__":
    # Run the hand-assembled suite with the plain-text runner when this
    # file is executed as a script.
    unittest.TextTestRunner().run(suite())

0 commit comments

Comments
 (0)