Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions episodes/24-diagnosing-issues-improving-robustness.md
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,14 @@ def test_patient_normalise(test, expected, expect_raises):

if expect_raises is not None:
with pytest.raises(expect_raises):
result = patient_normalise(np.array(test))
npt.assert_allclose(result, np.array(expected), rtol=1e-2, atol=1e-2)
patient_normalise(np.array(test))
else:
result = patient_normalise(np.array(test))
npt.assert_allclose(result, np.array(expected), rtol=1e-2, atol=1e-2)
```

Notice that under the `pytest.raises` context manager, it isn't necessary to perform the `npt.assert_allclose` because just the call to `patient_normalise()` will raise the `ValueError`.

Be sure to commit your changes so far and push them to GitHub.

::::::::::::::::::::::::::::::::::::::: challenge
Expand All @@ -629,9 +630,17 @@ You will find the Python function
[`isinstance`](https://docs.python.org/3/library/functions.html#isinstance)
useful here, as well as the Python exception
[`TypeError`](https://docs.python.org/3/library/exceptions.html#TypeError).

You can take this even further if your solution code involves multiple `ValueError`s or `TypeError`'s being raised at different locations.
In this case, you want to make sure the function is raising the precise `Exception` that occurs at a specific point in your function.
You might have noticed that `Exception`s can take a string argument corresponding to a message / description of the exception that has occurred.
The [`pytest.raises()`](https://docs.pytest.org/en/stable/reference/reference.html#pytest.raises) context manager can query this message with its `match=` argument.
See if you can use that to more precisely test the exceptions raised by your function.

Once you are done, commit your new files,
and push the new commits to your remote repository on GitHub.


::::::::::::::: solution

## Solution
Expand Down Expand Up @@ -683,6 +692,11 @@ from inflammation.models import patient_normalise
None,
TypeError,
),
(
[4, 5, 6],
None,
ValueError,
),
(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[0.33, 0.67, 1], [0.67, 0.83, 1], [0.78, 0.89, 1]],
Expand All @@ -695,8 +709,56 @@ def test_patient_normalise(test, expected, expect_raises):
test = np.array(test)
if expect_raises is not None:
with pytest.raises(expect_raises):
result = patient_normalise(test)
npt.assert_allclose(result, np.array(expected), rtol=1e-2, atol=1e-2)
patient_normalise(test)

else:
result = patient_normalise(test)
npt.assert_allclose(result, np.array(expected), rtol=1e-2, atol=1e-2)
...
```

Or, if you decided to match the specific exceptions in `test/test_models.py`:

```python
from inflammation.models import patient_normalise
...
@pytest.mark.parametrize(
"test, expected, expect_raises",
[
...
(
[[-1, 2, 3], [4, 5, 6], [7, 8, 9]],
None,
ValueError('inflammation values should be non-negative'),
),
(
[4, 5, 6],
None,
ValueError('inflammation array should be 2-dimensional'),
),
(
'hello',
None,
TypeError('data input should be ndarray'),
),
(
3,
None,
TypeError('data input should be ndarray'),
),
(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[0.33, 0.67, 1], [0.67, 0.83, 1], [0.78, 0.89, 1]],
None,
)
])
def test_patient_normalise(test, expected, expect_raises):
"""Test normalisation works for arrays of one and positive integers."""
if isinstance(test, list):
test = np.array(test)
if expect_raises is not None:
with pytest.raises(expect_raises, match=str(expect_raises)):
patient_normalise(test)

else:
result = patient_normalise(test)
Expand Down