Skip to content

Commit d4c0717

Browse files
committed
fix/test(replications): fix the algorithm find_position(), and add unit testing for it
1 parent 6b50c19 commit d4c0717

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

simulation/replications.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,17 +352,22 @@ def find_position(self, lst):
352352
int:
353353
Minimum replications required to meet and maintain precision.
354354
"""
355+
# Check if the list is empty or if no value is below the threshold
356+
if not lst or all(x is None or x >= self.half_width_precision
357+
for x in lst):
358+
return None
359+
355360
# Find the first non-None value in the list
356361
start_index = pd.Series(lst).first_valid_index()
357362

358363
# Iterate through the list, stopping when at last point where we still
359364
# have enough elements to look ahead
360365
if start_index is not None:
361-
for i in range(start_index, len(lst) - self.look_ahead + 1):
366+
for i in range(start_index, len(lst) - self.look_ahead):
362367
# Create slice of list with current value + lookahead
363368
# Check if all fall below the desired deviation
364369
if all(value < self.half_width_precision
365-
for value in lst[i:i+self.look_ahead]):
370+
for value in lst[i:i+self.look_ahead+1]):
366371
# Add one, so it is the number of reps, as is zero-indexed
367372
return i + 1
368373
return None

tests/test_unittest_replications.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,37 @@ def test_tabulizer_summary_table():
210210
assert df['lower_ci'].tolist() == [3, 8, 13]
211211
assert df['upper_ci'].tolist() == [7, 12, 17]
212212
assert df['deviation'].tolist() == [0.1, 0.2, 0.3]
213+
214+
215+
@pytest.mark.parametrize('lst, exp, look_ahead', [
216+
([None, None, 0.8, 0.4, 0.3], 4, 0), # Normal case
217+
([0.4, 0.3, 0.2, 0.1], 1, 0), # No None values
218+
([0.8, 0.9, 0.8, 0.7], None, 0), # No values below threshold
219+
([None, None, None, None], None, 0), # No values
220+
([], None, 0), # Empty list
221+
([None, None, 0.8, 0.8, 0.3, 0.3, 0.3], None, 3), # Not full lookahead
222+
([None, None, 0.8, 0.8, 0.3, 0.3, 0.3, 0.3], 5, 3) # Meets lookahead
223+
])
224+
def test_find_position(lst, exp, look_ahead):
225+
"""
226+
Test the find_position() method from ReplicationsAlgorithm.
227+
228+
Arguments:
229+
lst (list)
230+
List of values to input to find_position().
231+
exp (float)
232+
Expected result from find_position().
233+
look_ahead (int)
234+
Number of extra positions to check that they also fall under the
235+
threshold.
236+
"""
237+
# Set threshold to 0.5, with provided look_ahead
238+
alg = ReplicationsAlgorithm(half_width_precision=0.5,
239+
look_ahead=look_ahead)
240+
241+
# Get result from algorithm and compare to expected
242+
result = alg.find_position(lst)
243+
assert result == exp, (
244+
f'Ran find_position on: {lst} (threshold 0.5, look-ahead ' +
245+
f'{look_ahead}). Expected {exp}, but got {result}.'
246+
)

0 commit comments

Comments
 (0)