Skip to content

Commit be7eeb6

Browse files
committed
tests: equation.py
1 parent f51ee43 commit be7eeb6

File tree

1 file changed

+185
-174
lines changed

1 file changed

+185
-174
lines changed

tests/test_equation.py

Lines changed: 185 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -14,181 +14,192 @@
1414
##############################################################################
1515
"""Tests for refinableobj module."""
1616

17-
import unittest
17+
import pytest
1818

1919
import diffpy.srfit.equation.literals as literals
2020
from diffpy.srfit.equation import Equation
2121

22-
from .utils import _makeArgs, noObserversInGlobalBuilders
23-
24-
25-
class TestEquation(unittest.TestCase):
26-
27-
def testSimpleFunction(self):
28-
"""Test a simple function."""
29-
30-
# Make some variables
31-
v1, v2, v3, v4, c = _makeArgs(5)
32-
c.name = "c"
33-
c.const = True
34-
35-
# Make some operations
36-
mult = literals.MultiplicationOperator()
37-
root = mult2 = literals.MultiplicationOperator()
38-
plus = literals.AdditionOperator()
39-
minus = literals.SubtractionOperator()
40-
41-
# Create the equation c*(v1+v3)*(v4-v2)
42-
plus.addLiteral(v1)
43-
plus.addLiteral(v3)
44-
minus.addLiteral(v4)
45-
minus.addLiteral(v2)
46-
mult.addLiteral(plus)
47-
mult.addLiteral(minus)
48-
mult2.addLiteral(mult)
49-
mult2.addLiteral(c)
50-
51-
# Set the values of the variables.
52-
# The equation should evaluate to 2.5*(1+3)*(4-2) = 20
53-
v1.setValue(1)
54-
v2.setValue(2)
55-
v3.setValue(3)
56-
v4.setValue(4)
57-
c.setValue(2.5)
58-
59-
# Make an equation and test
60-
eq = Equation("eq", mult2)
61-
62-
self.assertTrue(eq._value is None)
63-
args = eq.args
64-
self.assertTrue(v1 in args)
65-
self.assertTrue(v2 in args)
66-
self.assertTrue(v3 in args)
67-
self.assertTrue(v4 in args)
68-
self.assertTrue(c not in args)
69-
self.assertTrue(root is eq.root)
70-
71-
self.assertTrue(v1 is eq.v1)
72-
self.assertTrue(v2 is eq.v2)
73-
self.assertTrue(v3 is eq.v3)
74-
self.assertTrue(v4 is eq.v4)
75-
76-
self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2)
77-
self.assertEqual(20, eq.getValue()) # same as above
78-
self.assertEqual(20, eq.value) # same as above
79-
self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2)
80-
self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0)
81-
self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0)
82-
self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0)
83-
84-
# Try some swapping
85-
eq.swap(v4, v1)
86-
self.assertTrue(eq._value is None)
87-
self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0)
88-
args = eq.args
89-
self.assertTrue(v4 not in args)
90-
91-
# Try to create a dependency loop
92-
self.assertRaises(ValueError, eq.swap, v1, eq.root)
93-
self.assertRaises(ValueError, eq.swap, v1, plus)
94-
self.assertRaises(ValueError, eq.swap, v1, minus)
95-
self.assertRaises(ValueError, eq.swap, v1, mult)
96-
self.assertRaises(ValueError, eq.swap, v1, root)
97-
98-
# Swap the root
99-
eq.swap(eq.root, v1)
100-
self.assertTrue(eq._value is None)
101-
self.assertEqual(v1.value, eq())
102-
103-
self.assertTrue(noObserversInGlobalBuilders())
104-
return
105-
106-
def testEmbeddedEquation(self):
107-
"""Test a simple function."""
108-
109-
# Make some variables
110-
v1, v2, v3, v4, c = _makeArgs(5)
111-
c.name = "c"
112-
c.const = True
113-
114-
# Make some operations
115-
mult = literals.MultiplicationOperator()
116-
mult2 = literals.MultiplicationOperator()
117-
plus = literals.AdditionOperator()
118-
minus = literals.SubtractionOperator()
119-
120-
# Create the equation c*(v1+v3)*(v4-v2)
121-
plus.addLiteral(v1)
122-
plus.addLiteral(v3)
123-
minus.addLiteral(v4)
124-
minus.addLiteral(v2)
125-
mult.addLiteral(plus)
126-
mult.addLiteral(minus)
127-
mult2.addLiteral(mult)
128-
mult2.addLiteral(c)
129-
130-
# Set the values of the variables.
131-
# The equation should evaluate to 2.5*(1+3)*(4-2) = 20
132-
v1.setValue(1)
133-
v2.setValue(2)
134-
v3.setValue(3)
135-
v4.setValue(4)
136-
c.setValue(2.5)
137-
138-
# Make an equation and test
139-
root = Equation("root", mult2)
140-
eq = Equation("eq", root)
141-
142-
self.assertTrue(eq._value is None)
143-
args = eq.args
144-
self.assertTrue(v1 in args)
145-
self.assertTrue(v2 in args)
146-
self.assertTrue(v3 in args)
147-
self.assertTrue(v4 in args)
148-
self.assertTrue(c not in args)
149-
self.assertTrue(root is eq.root)
150-
151-
self.assertTrue(v1 is eq.v1)
152-
self.assertTrue(v2 is eq.v2)
153-
self.assertTrue(v3 is eq.v3)
154-
self.assertTrue(v4 is eq.v4)
155-
156-
# Make sure the right messages get sent
157-
v1.value = 0
158-
self.assertTrue(root._value is None)
159-
self.assertTrue(eq._value is None)
160-
v1.value = 1
161-
162-
self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2)
163-
self.assertEqual(20, eq.getValue()) # same as above
164-
self.assertEqual(20, eq.value) # same as above
165-
self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2)
166-
self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0)
167-
self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0)
168-
self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0)
169-
170-
# Try some swapping.
171-
eq.swap(v4, v1)
172-
self.assertTrue(eq._value is None)
173-
self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0)
174-
args = eq.args
175-
self.assertTrue(v4 not in args)
176-
177-
# Try to create a dependency loop
178-
self.assertRaises(ValueError, eq.swap, v1, eq.root)
179-
self.assertRaises(ValueError, eq.swap, v1, plus)
180-
self.assertRaises(ValueError, eq.swap, v1, minus)
181-
self.assertRaises(ValueError, eq.swap, v1, mult)
182-
self.assertRaises(ValueError, eq.swap, v1, root)
183-
184-
# Swap the root
185-
eq.swap(eq.root, v1)
186-
self.assertTrue(eq._value is None)
187-
self.assertEqual(v1.value, eq())
188-
189-
self.assertTrue(noObserversInGlobalBuilders())
190-
return
191-
192-
193-
if __name__ == "__main__":
194-
unittest.main()
22+
23+
def testSimpleFunction(make_args, noObserversInGlobalBuilders):
24+
"""Test a simple function."""
25+
26+
# Make some variables
27+
v1, v2, v3, v4, c = make_args(5)
28+
c.name = "c"
29+
c.const = True
30+
31+
# Make some operations
32+
mult = literals.MultiplicationOperator()
33+
root = mult2 = literals.MultiplicationOperator()
34+
plus = literals.AdditionOperator()
35+
minus = literals.SubtractionOperator()
36+
37+
# Create the equation c*(v1+v3)*(v4-v2)
38+
plus.addLiteral(v1)
39+
plus.addLiteral(v3)
40+
minus.addLiteral(v4)
41+
minus.addLiteral(v2)
42+
mult.addLiteral(plus)
43+
mult.addLiteral(minus)
44+
mult2.addLiteral(mult)
45+
mult2.addLiteral(c)
46+
47+
# Set the values of the variables.
48+
# The equation should evaluate to 2.5*(1+3)*(4-2) = 20
49+
v1.setValue(1)
50+
v2.setValue(2)
51+
v3.setValue(3)
52+
v4.setValue(4)
53+
c.setValue(2.5)
54+
55+
# Make an equation and test
56+
eq = Equation("eq", mult2)
57+
58+
assert eq._value is None
59+
args = eq.args
60+
assert v1 in args
61+
assert v2 in args
62+
assert v3 in args
63+
assert v4 in args
64+
assert c not in args
65+
assert root is eq.root
66+
67+
assert v1 is eq.v1
68+
assert v2 is eq.v2
69+
assert v3 is eq.v3
70+
assert v4 is eq.v4
71+
72+
assert 20 == eq() # 20 = 2.5*(1+3)*(4-2)
73+
assert 20 == eq.getValue() # same as above
74+
assert 20 == eq.value # same as above
75+
assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2)
76+
assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0)
77+
assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0)
78+
assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0)
79+
80+
# Try some swapping
81+
eq.swap(v4, v1)
82+
assert eq._value is None
83+
assert 15 == eq() # 15 = 2.5*(2+1)*(2-0)
84+
args = eq.args
85+
assert v4 not in args
86+
87+
# Try to create a dependency loop
88+
with pytest.raises(ValueError):
89+
eq.swap(v1, eq.root)
90+
91+
with pytest.raises(ValueError):
92+
eq.swap(v1, plus)
93+
94+
with pytest.raises(ValueError):
95+
eq.swap(v1, minus)
96+
97+
with pytest.raises(ValueError):
98+
eq.swap(v1, mult)
99+
100+
with pytest.raises(ValueError):
101+
eq.swap(v1, root)
102+
103+
# Swap the root
104+
eq.swap(eq.root, v1)
105+
assert eq._value is None
106+
assert v1.value, eq()
107+
108+
assert noObserversInGlobalBuilders
109+
return
110+
111+
112+
def testEmbeddedEquation(make_args, noObserversInGlobalBuilders):
113+
"""Test a simple function."""
114+
115+
# Make some variables
116+
v1, v2, v3, v4, c = make_args(5)
117+
c.name = "c"
118+
c.const = True
119+
120+
# Make some operations
121+
mult = literals.MultiplicationOperator()
122+
mult2 = literals.MultiplicationOperator()
123+
plus = literals.AdditionOperator()
124+
minus = literals.SubtractionOperator()
125+
126+
# Create the equation c*(v1+v3)*(v4-v2)
127+
plus.addLiteral(v1)
128+
plus.addLiteral(v3)
129+
minus.addLiteral(v4)
130+
minus.addLiteral(v2)
131+
mult.addLiteral(plus)
132+
mult.addLiteral(minus)
133+
mult2.addLiteral(mult)
134+
mult2.addLiteral(c)
135+
136+
# Set the values of the variables.
137+
# The equation should evaluate to 2.5*(1+3)*(4-2) = 20
138+
v1.setValue(1)
139+
v2.setValue(2)
140+
v3.setValue(3)
141+
v4.setValue(4)
142+
c.setValue(2.5)
143+
144+
# Make an equation and test
145+
root = Equation("root", mult2)
146+
eq = Equation("eq", root)
147+
148+
assert eq._value is None
149+
args = eq.args
150+
assert v1 in args
151+
assert v2 in args
152+
assert v3 in args
153+
assert v4 in args
154+
assert c not in args
155+
assert root is eq.root
156+
157+
assert v1 is eq.v1
158+
assert v2 is eq.v2
159+
assert v3 is eq.v3
160+
assert v4 is eq.v4
161+
162+
# Make sure the right messages get sent
163+
v1.value = 0
164+
assert root._value is None
165+
assert eq._value is None
166+
v1.value = 1
167+
168+
assert 20 == eq() # 20 = 2.5*(1+3)*(4-2)
169+
assert 20 == eq.getValue() # same as above
170+
assert 20 == eq.value # same as above
171+
assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2)
172+
assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0)
173+
assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0)
174+
assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0)
175+
176+
# Try some swapping.
177+
eq.swap(v4, v1)
178+
assert eq._value is None
179+
assert 15 == eq() # 15 = 2.5*(2+1)*(2-0)
180+
args = eq.args
181+
assert v4 not in args
182+
183+
# Try to create a dependency loop
184+
with pytest.raises(ValueError):
185+
eq.swap(v1, eq.root)
186+
187+
with pytest.raises(ValueError):
188+
eq.swap(v1, plus)
189+
190+
with pytest.raises(ValueError):
191+
eq.swap(v1, minus)
192+
193+
with pytest.raises(ValueError):
194+
eq.swap(v1, mult)
195+
196+
with pytest.raises(ValueError):
197+
eq.swap(v1, root)
198+
199+
# Swap the root
200+
eq.swap(eq.root, v1)
201+
assert eq._value is None
202+
assert v1.value == eq()
203+
204+
assert noObserversInGlobalBuilders
205+
return

0 commit comments

Comments
 (0)