Skip to content

Commit 5da62e7

Browse files
committed
pivoting: Add tol options to _lex_min_ratio_test
1 parent 9cea0e0 commit 5da62e7

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

quantecon/optimize/pivoting.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
@jit(nopython=True, cache=True)
15-
def _pivoting(tableau, pivot, pivot_row):
15+
def _pivoting(tableau, pivot_col, pivot_row):
1616
"""
1717
Perform a pivoting step. Modify `tableau` in place.
1818
@@ -21,8 +21,8 @@ def _pivoting(tableau, pivot, pivot_row):
2121
tableau : ndarray(float, ndim=2)
2222
Array containing the tableau.
2323
24-
pivot : scalar(int)
25-
Pivot.
24+
pivot_col : scalar(int)
25+
Pivot column index.
2626
2727
pivot_row : scalar(int)
2828
Pivot row index.
@@ -35,14 +35,14 @@ def _pivoting(tableau, pivot, pivot_row):
3535
"""
3636
nrows, ncols = tableau.shape
3737

38-
pivot_elt = tableau[pivot_row, pivot]
38+
pivot_elt = tableau[pivot_row, pivot_col]
3939
for j in range(ncols):
4040
tableau[pivot_row, j] /= pivot_elt
4141

4242
for i in range(nrows):
4343
if i == pivot_row:
4444
continue
45-
multiplier = tableau[i, pivot]
45+
multiplier = tableau[i, pivot_col]
4646
if multiplier == 0:
4747
continue
4848
for j in range(ncols):
@@ -53,7 +53,8 @@ def _pivoting(tableau, pivot, pivot_row):
5353

5454
@jit(nopython=True, cache=True)
5555
def _min_ratio_test_no_tie_breaking(tableau, pivot, test_col,
56-
argmins, num_candidates):
56+
argmins, num_candidates,
57+
tol_piv, tol_ratio_diff):
5758
"""
5859
Perform the minimum ratio test, without tie breaking, for the
5960
candidate rows in `argmins[:num_candidates]`. Return the number
@@ -78,6 +79,13 @@ def _min_ratio_test_no_tie_breaking(tableau, pivot, test_col,
7879
num_candidates : scalar(int)
7980
Number of candidate rows in `argmins`.
8081
82+
tol_piv : scalar(float)
83+
Pivot tolerance below which a number is considered to be
84+
nonpositive.
85+
86+
tol_ratio_diff : scalar(float)
87+
Tolerance to determine a tie between ratio values.
88+
8189
Returns
8290
-------
8391
num_argmins : scalar(int)
@@ -89,12 +97,12 @@ def _min_ratio_test_no_tie_breaking(tableau, pivot, test_col,
8997

9098
for k in range(num_candidates):
9199
i = argmins[k]
92-
if tableau[i, pivot] <= TOL_PIV: # Treated as nonpositive
100+
if tableau[i, pivot] <= tol_piv: # Treated as nonpositive
93101
continue
94102
ratio = tableau[i, test_col] / tableau[i, pivot]
95-
if ratio > ratio_min + TOL_RATIO_DIFF: # Ratio large for i
103+
if ratio > ratio_min + tol_ratio_diff: # Ratio large for i
96104
continue
97-
elif ratio < ratio_min - TOL_RATIO_DIFF: # Ratio smaller for i
105+
elif ratio < ratio_min - tol_ratio_diff: # Ratio smaller for i
98106
ratio_min = ratio
99107
num_argmins = 1
100108
else: # Ratio equal
@@ -105,7 +113,8 @@ def _min_ratio_test_no_tie_breaking(tableau, pivot, test_col,
105113

106114

107115
@jit(nopython=True, cache=True)
108-
def _lex_min_ratio_test(tableau, pivot, slack_start, argmins):
116+
def _lex_min_ratio_test(tableau, pivot, slack_start, argmins,
117+
tol_piv=TOL_PIV, tol_ratio_diff=TOL_RATIO_DIFF):
109118
"""
110119
Perform the lexico-minimum ratio test.
111120
@@ -124,6 +133,14 @@ def _lex_min_ratio_test(tableau, pivot, slack_start, argmins):
124133
Empty array used to store the row indices. Its length must be no
125134
smaller than the number of the rows of `tableau`.
126135
136+
tol_piv : scalar(float), optional
137+
Pivot tolerance below which a number is considered to be
138+
nonpositive. Default value is {TOL_PIV}.
139+
140+
tol_ratio_diff : scalar(float), optional
141+
Tolerance to determine a tie between ratio values. Default value
142+
is {TOL_RATIO_DIFF}.
143+
127144
Returns
128145
-------
129146
found : bool
@@ -142,17 +159,25 @@ def _lex_min_ratio_test(tableau, pivot, slack_start, argmins):
142159
for i in range(nrows):
143160
argmins[i] = i
144161

145-
num_argmins = _min_ratio_test_no_tie_breaking(tableau, pivot, -1,
146-
argmins, num_candidates)
162+
num_argmins = _min_ratio_test_no_tie_breaking(
163+
tableau, pivot, -1, argmins, num_candidates, tol_piv, tol_ratio_diff
164+
)
147165
if num_argmins == 1:
148166
found = True
149167
elif num_argmins >= 2:
150168
for j in range(slack_start, slack_start+nrows):
151169
if j == pivot:
152170
continue
153-
num_argmins = _min_ratio_test_no_tie_breaking(tableau, pivot, j,
154-
argmins, num_argmins)
171+
num_argmins = _min_ratio_test_no_tie_breaking(
172+
tableau, pivot, j, argmins, num_argmins,
173+
tol_piv, tol_ratio_diff
174+
)
155175
if num_argmins == 1:
156176
found = True
157177
break
158178
return found, argmins[0]
179+
180+
181+
_lex_min_ratio_test.__doc__ = _lex_min_ratio_test.__doc__.format(
182+
TOL_PIV=TOL_PIV, TOL_RATIO_DIFF=TOL_RATIO_DIFF
183+
)

0 commit comments

Comments
 (0)