
Commit 6863d6e

Lucas-Prates authored and Gui-FernandesBR committed

TST: complementing tests for sensitivity analysis and removing duplicate piece of code.

1 parent c07b50b · commit 6863d6e

2 files changed (+72 −17 lines)

rocketpy/sensitivity/sensitivity_model.py (+2 −6)
@@ -140,10 +140,6 @@ def set_target_variables_nominal(self, target_variables_nominal_value):
             self.target_variables_info[target_variable]["nominal_value"] = (
                 target_variables_nominal_value[i]
             )
-        for i, target_variable in enumerate(self.target_variables_names):
-            self.target_variables_info[target_variable]["nominal_value"] = (
-                target_variables_nominal_value[i]
-            )
 
         self._nominal_target_passed = True
 
@@ -356,12 +352,12 @@ def __check_requirements(self):
             version = ">=0" if not version else version
             try:
                 check_requirement_version(module_name, version)
-            except (ValueError, ImportError) as e:
+            except (ValueError, ImportError) as e:  # pragma: no cover
                 has_error = True
                 print(
                     f"The following error occurred while importing {module_name}: {e}"
                 )
-        if has_error:
+        if has_error:  # pragma: no cover
             print(
                 "Given the above errors, some methods may not work. Please run "
                 + "'pip install rocketpy[sensitivity]' to install extra requirements."

tests/unit/test_sensitivity.py (+70 −11)
@@ -1,13 +1,11 @@
+from unittest.mock import patch
+
 import numpy as np
 import pytest
 
 from rocketpy.sensitivity import SensitivityModel
 
-# TODO: for some weird reason, these tests are not passing in the CI, but
-# passing locally. Need to investigate why.
-
 
-@pytest.mark.skip(reason="legacy test")
 def test_initialization():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -21,7 +19,6 @@ def test_initialization():
     assert not model._fitted
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_set_parameters_nominal():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -35,8 +32,16 @@ def test_set_parameters_nominal():
     assert model.parameters_info["param1"]["nominal_mean"] == 1.0
     assert model.parameters_info["param2"]["nominal_sd"] == 0.2
 
+    # check dimensions mismatch error raise
+    incorrect_nominal_mean = np.array([1.0])
+    with pytest.raises(ValueError):
+        model.set_parameters_nominal(incorrect_nominal_mean, parameters_nominal_sd)
+
+    incorrect_nominal_sd = np.array([0.1])
+    with pytest.raises(ValueError):
+        model.set_parameters_nominal(parameters_nominal_mean, incorrect_nominal_sd)
+
 
-@pytest.mark.skip(reason="legacy test")
 def test_set_target_variables_nominal():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
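These additions check that `set_parameters_nominal` rejects arrays whose length does not match the number of registered parameters. A plausible sketch of the kind of validation being exercised, purely illustrative (the helper name is an assumption, not rocketpy's actual implementation):

    import numpy as np

    def check_length(values, expected_length, name):
        # Raise the ValueError the tests above expect when the array's
        # length differs from the number of registered parameters.
        values = np.atleast_1d(values)
        if values.shape[0] != expected_length:
            raise ValueError(
                f"'{name}' has length {values.shape[0]}, expected {expected_length}."
            )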
@@ -49,9 +54,13 @@ def test_set_target_variables_nominal():
     assert model.target_variables_info["target1"]["nominal_value"] == 10.0
     assert model.target_variables_info["target2"]["nominal_value"] == 20.0
 
+    # check dimensions mismatch error raise
+    incorrect_nominal_value = np.array([10.0])
+    with pytest.raises(ValueError):
+        model.set_target_variables_nominal(incorrect_nominal_value)
+
 
-@pytest.mark.skip(reason="legacy test")
-def test_fit_method():
+def test_fit_method_one_target():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1"]
     model = SensitivityModel(parameters_names, target_variables_names)
@@ -65,7 +74,20 @@ def test_fit_method():
     assert model.number_of_samples == 3
 
 
-@pytest.mark.skip(reason="legacy test")
+def test_fit_method_multiple_target():
+    parameters_names = ["param1", "param2"]
+    target_variables_names = ["target1", "target2"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
+    target_data = np.array([[10.0, 12.0, 14.0], [11.0, 13.0, 17.0]]).T
+
+    model.fit(parameters_matrix, target_data)
+
+    assert model._fitted
+    assert model.number_of_samples == 3
+
+
 def test_fit_raises_error_on_mismatched_dimensions():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1"]
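In the new multi-target test, `target_data` is built by transposing a 2×3 array, so both inputs are laid out as samples × variables: 3 samples of 2 parameters and 3 samples of 2 targets, which is why `number_of_samples` comes out as 3. A quick shape check using only numpy:

    import numpy as np

    parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
    target_data = np.array([[10.0, 12.0, 14.0], [11.0, 13.0, 17.0]]).T

    print(parameters_matrix.shape)  # (3, 2): 3 samples, 2 parameters
    print(target_data.shape)        # (3, 2): 3 samples, 2 target variables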
@@ -78,7 +100,6 @@ def test_fit_raises_error_on_mismatched_dimensions():
         model.fit(parameters_matrix, target_data)
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_check_conformity():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -90,7 +111,6 @@ def test_check_conformity():
     model._SensitivityModel__check_conformity(parameters_matrix, target_data)
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_check_conformity_raises_error():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
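The tests reach the private `__check_conformity` method through its name-mangled form: Python rewrites double-underscore names defined inside a class to `_ClassName__name`, so from outside `SensitivityModel` the method is `_SensitivityModel__check_conformity`. A toy illustration of the mangling rule:

    class Demo:
        def __hidden(self):
            return "reached"

    d = Demo()
    # Outside the class, the double-underscore method is only reachable
    # through its mangled name, exactly as the tests above do.
    print(d._Demo__hidden())  # -> "reached"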
@@ -101,3 +121,42 @@ def test_check_conformity_raises_error():
 
     with pytest.raises(ValueError):
         model._SensitivityModel__check_conformity(parameters_matrix, target_data)
+
+    parameters_matrix2 = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix2, target_data)
+
+    target_data2 = np.array([10.0, 12.0])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix, target_data2)
+
+    target_variables_names = ["target1"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    target_data = np.array([[10.0, 20.0], [12.0, 22.0], [14.0, 24.0]])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix, target_data)
+
+
+@patch("matplotlib.pyplot.show")
+def test_prints_and_plots(mock_show):  # pylint: disable=unused-argument
+    parameters_names = ["param1", "param2"]
+    target_variables_names = ["target1"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
+    target_data = np.array([10.0, 12.0, 14.0])
+
+    # tests if an error is raised if summary is called before print
+    with pytest.raises(ValueError):
+        model.info()
+
+    model.fit(parameters_matrix, target_data)
+    assert model.all_info() is None
+
+    nominal_target = np.array([12.0])
+    model.set_target_variables_nominal(nominal_target)
+    assert model.all_info() is None
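`@patch("matplotlib.pyplot.show")` swaps `plt.show` for a mock during `test_prints_and_plots`, so the plots produced by `all_info()` never open blocking GUI windows in CI; the unused `mock_show` argument is the mock object that `patch` injects. The same pattern in isolation, assuming only the standard library and matplotlib:

    from unittest.mock import patch

    import matplotlib.pyplot as plt

    @patch("matplotlib.pyplot.show")
    def render_silently(mock_show):
        # Inside this function plt.show is a mock, so nothing blocks.
        plt.plot([1, 2, 3])
        plt.show()
        assert mock_show.called  # the code under test did call show()

    render_silently()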
