|
14 | 14 | ##############################################################################
|
15 | 15 | """Tests for refinableobj module."""
|
16 | 16 |
|
17 |
| -import unittest |
| 17 | +import pytest |
18 | 18 |
|
19 | 19 | import diffpy.srfit.equation.literals as literals
|
20 | 20 | from diffpy.srfit.equation import Equation
|
21 | 21 |
|
22 |
| -from .utils import _makeArgs, noObserversInGlobalBuilders |
23 |
| - |
24 |
| - |
25 |
| -class TestEquation(unittest.TestCase): |
26 |
| - |
27 |
| - def testSimpleFunction(self): |
28 |
| - """Test a simple function.""" |
29 |
| - |
30 |
| - # Make some variables |
31 |
| - v1, v2, v3, v4, c = _makeArgs(5) |
32 |
| - c.name = "c" |
33 |
| - c.const = True |
34 |
| - |
35 |
| - # Make some operations |
36 |
| - mult = literals.MultiplicationOperator() |
37 |
| - root = mult2 = literals.MultiplicationOperator() |
38 |
| - plus = literals.AdditionOperator() |
39 |
| - minus = literals.SubtractionOperator() |
40 |
| - |
41 |
| - # Create the equation c*(v1+v3)*(v4-v2) |
42 |
| - plus.addLiteral(v1) |
43 |
| - plus.addLiteral(v3) |
44 |
| - minus.addLiteral(v4) |
45 |
| - minus.addLiteral(v2) |
46 |
| - mult.addLiteral(plus) |
47 |
| - mult.addLiteral(minus) |
48 |
| - mult2.addLiteral(mult) |
49 |
| - mult2.addLiteral(c) |
50 |
| - |
51 |
| - # Set the values of the variables. |
52 |
| - # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 |
53 |
| - v1.setValue(1) |
54 |
| - v2.setValue(2) |
55 |
| - v3.setValue(3) |
56 |
| - v4.setValue(4) |
57 |
| - c.setValue(2.5) |
58 |
| - |
59 |
| - # Make an equation and test |
60 |
| - eq = Equation("eq", mult2) |
61 |
| - |
62 |
| - self.assertTrue(eq._value is None) |
63 |
| - args = eq.args |
64 |
| - self.assertTrue(v1 in args) |
65 |
| - self.assertTrue(v2 in args) |
66 |
| - self.assertTrue(v3 in args) |
67 |
| - self.assertTrue(v4 in args) |
68 |
| - self.assertTrue(c not in args) |
69 |
| - self.assertTrue(root is eq.root) |
70 |
| - |
71 |
| - self.assertTrue(v1 is eq.v1) |
72 |
| - self.assertTrue(v2 is eq.v2) |
73 |
| - self.assertTrue(v3 is eq.v3) |
74 |
| - self.assertTrue(v4 is eq.v4) |
75 |
| - |
76 |
| - self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2) |
77 |
| - self.assertEqual(20, eq.getValue()) # same as above |
78 |
| - self.assertEqual(20, eq.value) # same as above |
79 |
| - self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2) |
80 |
| - self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0) |
81 |
| - self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0) |
82 |
| - self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0) |
83 |
| - |
84 |
| - # Try some swapping |
85 |
| - eq.swap(v4, v1) |
86 |
| - self.assertTrue(eq._value is None) |
87 |
| - self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0) |
88 |
| - args = eq.args |
89 |
| - self.assertTrue(v4 not in args) |
90 |
| - |
91 |
| - # Try to create a dependency loop |
92 |
| - self.assertRaises(ValueError, eq.swap, v1, eq.root) |
93 |
| - self.assertRaises(ValueError, eq.swap, v1, plus) |
94 |
| - self.assertRaises(ValueError, eq.swap, v1, minus) |
95 |
| - self.assertRaises(ValueError, eq.swap, v1, mult) |
96 |
| - self.assertRaises(ValueError, eq.swap, v1, root) |
97 |
| - |
98 |
| - # Swap the root |
99 |
| - eq.swap(eq.root, v1) |
100 |
| - self.assertTrue(eq._value is None) |
101 |
| - self.assertEqual(v1.value, eq()) |
102 |
| - |
103 |
| - self.assertTrue(noObserversInGlobalBuilders()) |
104 |
| - return |
105 |
| - |
106 |
| - def testEmbeddedEquation(self): |
107 |
| - """Test a simple function.""" |
108 |
| - |
109 |
| - # Make some variables |
110 |
| - v1, v2, v3, v4, c = _makeArgs(5) |
111 |
| - c.name = "c" |
112 |
| - c.const = True |
113 |
| - |
114 |
| - # Make some operations |
115 |
| - mult = literals.MultiplicationOperator() |
116 |
| - mult2 = literals.MultiplicationOperator() |
117 |
| - plus = literals.AdditionOperator() |
118 |
| - minus = literals.SubtractionOperator() |
119 |
| - |
120 |
| - # Create the equation c*(v1+v3)*(v4-v2) |
121 |
| - plus.addLiteral(v1) |
122 |
| - plus.addLiteral(v3) |
123 |
| - minus.addLiteral(v4) |
124 |
| - minus.addLiteral(v2) |
125 |
| - mult.addLiteral(plus) |
126 |
| - mult.addLiteral(minus) |
127 |
| - mult2.addLiteral(mult) |
128 |
| - mult2.addLiteral(c) |
129 |
| - |
130 |
| - # Set the values of the variables. |
131 |
| - # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 |
132 |
| - v1.setValue(1) |
133 |
| - v2.setValue(2) |
134 |
| - v3.setValue(3) |
135 |
| - v4.setValue(4) |
136 |
| - c.setValue(2.5) |
137 |
| - |
138 |
| - # Make an equation and test |
139 |
| - root = Equation("root", mult2) |
140 |
| - eq = Equation("eq", root) |
141 |
| - |
142 |
| - self.assertTrue(eq._value is None) |
143 |
| - args = eq.args |
144 |
| - self.assertTrue(v1 in args) |
145 |
| - self.assertTrue(v2 in args) |
146 |
| - self.assertTrue(v3 in args) |
147 |
| - self.assertTrue(v4 in args) |
148 |
| - self.assertTrue(c not in args) |
149 |
| - self.assertTrue(root is eq.root) |
150 |
| - |
151 |
| - self.assertTrue(v1 is eq.v1) |
152 |
| - self.assertTrue(v2 is eq.v2) |
153 |
| - self.assertTrue(v3 is eq.v3) |
154 |
| - self.assertTrue(v4 is eq.v4) |
155 |
| - |
156 |
| - # Make sure the right messages get sent |
157 |
| - v1.value = 0 |
158 |
| - self.assertTrue(root._value is None) |
159 |
| - self.assertTrue(eq._value is None) |
160 |
| - v1.value = 1 |
161 |
| - |
162 |
| - self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2) |
163 |
| - self.assertEqual(20, eq.getValue()) # same as above |
164 |
| - self.assertEqual(20, eq.value) # same as above |
165 |
| - self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2) |
166 |
| - self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0) |
167 |
| - self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0) |
168 |
| - self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0) |
169 |
| - |
170 |
| - # Try some swapping. |
171 |
| - eq.swap(v4, v1) |
172 |
| - self.assertTrue(eq._value is None) |
173 |
| - self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0) |
174 |
| - args = eq.args |
175 |
| - self.assertTrue(v4 not in args) |
176 |
| - |
177 |
| - # Try to create a dependency loop |
178 |
| - self.assertRaises(ValueError, eq.swap, v1, eq.root) |
179 |
| - self.assertRaises(ValueError, eq.swap, v1, plus) |
180 |
| - self.assertRaises(ValueError, eq.swap, v1, minus) |
181 |
| - self.assertRaises(ValueError, eq.swap, v1, mult) |
182 |
| - self.assertRaises(ValueError, eq.swap, v1, root) |
183 |
| - |
184 |
| - # Swap the root |
185 |
| - eq.swap(eq.root, v1) |
186 |
| - self.assertTrue(eq._value is None) |
187 |
| - self.assertEqual(v1.value, eq()) |
188 |
| - |
189 |
| - self.assertTrue(noObserversInGlobalBuilders()) |
190 |
| - return |
191 |
| - |
192 |
| - |
193 |
| -if __name__ == "__main__": |
194 |
| - unittest.main() |
| 22 | + |
| 23 | +def testSimpleFunction(make_args, noObserversInGlobalBuilders): |
| 24 | + """Test a simple function.""" |
| 25 | + |
| 26 | + # Make some variables |
| 27 | + v1, v2, v3, v4, c = make_args(5) |
| 28 | + c.name = "c" |
| 29 | + c.const = True |
| 30 | + |
| 31 | + # Make some operations |
| 32 | + mult = literals.MultiplicationOperator() |
| 33 | + root = mult2 = literals.MultiplicationOperator() |
| 34 | + plus = literals.AdditionOperator() |
| 35 | + minus = literals.SubtractionOperator() |
| 36 | + |
| 37 | + # Create the equation c*(v1+v3)*(v4-v2) |
| 38 | + plus.addLiteral(v1) |
| 39 | + plus.addLiteral(v3) |
| 40 | + minus.addLiteral(v4) |
| 41 | + minus.addLiteral(v2) |
| 42 | + mult.addLiteral(plus) |
| 43 | + mult.addLiteral(minus) |
| 44 | + mult2.addLiteral(mult) |
| 45 | + mult2.addLiteral(c) |
| 46 | + |
| 47 | + # Set the values of the variables. |
| 48 | + # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 |
| 49 | + v1.setValue(1) |
| 50 | + v2.setValue(2) |
| 51 | + v3.setValue(3) |
| 52 | + v4.setValue(4) |
| 53 | + c.setValue(2.5) |
| 54 | + |
| 55 | + # Make an equation and test |
| 56 | + eq = Equation("eq", mult2) |
| 57 | + |
| 58 | + assert eq._value is None |
| 59 | + args = eq.args |
| 60 | + assert v1 in args |
| 61 | + assert v2 in args |
| 62 | + assert v3 in args |
| 63 | + assert v4 in args |
| 64 | + assert c not in args |
| 65 | + assert root is eq.root |
| 66 | + |
| 67 | + assert v1 is eq.v1 |
| 68 | + assert v2 is eq.v2 |
| 69 | + assert v3 is eq.v3 |
| 70 | + assert v4 is eq.v4 |
| 71 | + |
| 72 | + assert 20 == eq() # 20 = 2.5*(1+3)*(4-2) |
| 73 | + assert 20 == eq.getValue() # same as above |
| 74 | + assert 20 == eq.value # same as above |
| 75 | + assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2) |
| 76 | + assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0) |
| 77 | + assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0) |
| 78 | + assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0) |
| 79 | + |
| 80 | + # Try some swapping |
| 81 | + eq.swap(v4, v1) |
| 82 | + assert eq._value is None |
| 83 | + assert 15 == eq() # 15 = 2.5*(2+1)*(2-0) |
| 84 | + args = eq.args |
| 85 | + assert v4 not in args |
| 86 | + |
| 87 | + # Try to create a dependency loop |
| 88 | + with pytest.raises(ValueError): |
| 89 | + eq.swap(v1, eq.root) |
| 90 | + |
| 91 | + with pytest.raises(ValueError): |
| 92 | + eq.swap(v1, plus) |
| 93 | + |
| 94 | + with pytest.raises(ValueError): |
| 95 | + eq.swap(v1, minus) |
| 96 | + |
| 97 | + with pytest.raises(ValueError): |
| 98 | + eq.swap(v1, mult) |
| 99 | + |
| 100 | + with pytest.raises(ValueError): |
| 101 | + eq.swap(v1, root) |
| 102 | + |
| 103 | + # Swap the root |
| 104 | + eq.swap(eq.root, v1) |
| 105 | + assert eq._value is None |
| 106 | + assert v1.value, eq() |
| 107 | + |
| 108 | + assert noObserversInGlobalBuilders |
| 109 | + return |
| 110 | + |
| 111 | + |
| 112 | +def testEmbeddedEquation(make_args, noObserversInGlobalBuilders): |
| 113 | + """Test a simple function.""" |
| 114 | + |
| 115 | + # Make some variables |
| 116 | + v1, v2, v3, v4, c = make_args(5) |
| 117 | + c.name = "c" |
| 118 | + c.const = True |
| 119 | + |
| 120 | + # Make some operations |
| 121 | + mult = literals.MultiplicationOperator() |
| 122 | + mult2 = literals.MultiplicationOperator() |
| 123 | + plus = literals.AdditionOperator() |
| 124 | + minus = literals.SubtractionOperator() |
| 125 | + |
| 126 | + # Create the equation c*(v1+v3)*(v4-v2) |
| 127 | + plus.addLiteral(v1) |
| 128 | + plus.addLiteral(v3) |
| 129 | + minus.addLiteral(v4) |
| 130 | + minus.addLiteral(v2) |
| 131 | + mult.addLiteral(plus) |
| 132 | + mult.addLiteral(minus) |
| 133 | + mult2.addLiteral(mult) |
| 134 | + mult2.addLiteral(c) |
| 135 | + |
| 136 | + # Set the values of the variables. |
| 137 | + # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 |
| 138 | + v1.setValue(1) |
| 139 | + v2.setValue(2) |
| 140 | + v3.setValue(3) |
| 141 | + v4.setValue(4) |
| 142 | + c.setValue(2.5) |
| 143 | + |
| 144 | + # Make an equation and test |
| 145 | + root = Equation("root", mult2) |
| 146 | + eq = Equation("eq", root) |
| 147 | + |
| 148 | + assert eq._value is None |
| 149 | + args = eq.args |
| 150 | + assert v1 in args |
| 151 | + assert v2 in args |
| 152 | + assert v3 in args |
| 153 | + assert v4 in args |
| 154 | + assert c not in args |
| 155 | + assert root is eq.root |
| 156 | + |
| 157 | + assert v1 is eq.v1 |
| 158 | + assert v2 is eq.v2 |
| 159 | + assert v3 is eq.v3 |
| 160 | + assert v4 is eq.v4 |
| 161 | + |
| 162 | + # Make sure the right messages get sent |
| 163 | + v1.value = 0 |
| 164 | + assert root._value is None |
| 165 | + assert eq._value is None |
| 166 | + v1.value = 1 |
| 167 | + |
| 168 | + assert 20 == eq() # 20 = 2.5*(1+3)*(4-2) |
| 169 | + assert 20 == eq.getValue() # same as above |
| 170 | + assert 20 == eq.value # same as above |
| 171 | + assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2) |
| 172 | + assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0) |
| 173 | + assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0) |
| 174 | + assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0) |
| 175 | + |
| 176 | + # Try some swapping. |
| 177 | + eq.swap(v4, v1) |
| 178 | + assert eq._value is None |
| 179 | + assert 15 == eq() # 15 = 2.5*(2+1)*(2-0) |
| 180 | + args = eq.args |
| 181 | + assert v4 not in args |
| 182 | + |
| 183 | + # Try to create a dependency loop |
| 184 | + with pytest.raises(ValueError): |
| 185 | + eq.swap(v1, eq.root) |
| 186 | + |
| 187 | + with pytest.raises(ValueError): |
| 188 | + eq.swap(v1, plus) |
| 189 | + |
| 190 | + with pytest.raises(ValueError): |
| 191 | + eq.swap(v1, minus) |
| 192 | + |
| 193 | + with pytest.raises(ValueError): |
| 194 | + eq.swap(v1, mult) |
| 195 | + |
| 196 | + with pytest.raises(ValueError): |
| 197 | + eq.swap(v1, root) |
| 198 | + |
| 199 | + # Swap the root |
| 200 | + eq.swap(eq.root, v1) |
| 201 | + assert eq._value is None |
| 202 | + assert v1.value == eq() |
| 203 | + |
| 204 | + assert noObserversInGlobalBuilders |
| 205 | + return |
0 commit comments