Skip to content

Commit 9bd4383

Browse files
TST: adds more unit tests to the codebase
MNT: linters TST: complementing tests for sensitivity analysis and removing duplicate piece of code. DEV: add pragma comments to exclude specific lines from coverage MNT: fix pylint error
1 parent 6c656f5 commit 9bd4383

File tree

15 files changed

+440
-32
lines changed

15 files changed

+440
-32
lines changed

rocketpy/environment/environment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2577,7 +2577,7 @@ def set_earth_geometry(self, datum):
25772577
}
25782578
try:
25792579
return ellipsoid[datum]
2580-
except KeyError as e:
2580+
except KeyError as e: # pragma: no cover
25812581
available_datums = ', '.join(ellipsoid.keys())
25822582
raise AttributeError(
25832583
f"The reference system '{datum}' is not recognized. Please use one of "

rocketpy/mathutils/function.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3119,12 +3119,12 @@ def compose(self, func, extrapolate=False):
31193119
The result of inputting the function into the function.
31203120
"""
31213121
# Check if the input is a function
3122-
if not isinstance(func, Function):
3122+
if not isinstance(func, Function): # pragma: no cover
31233123
raise TypeError("Input must be a Function object.")
31243124

31253125
if isinstance(self.source, np.ndarray) and isinstance(func.source, np.ndarray):
31263126
# Perform bounds check for composition
3127-
if not extrapolate:
3127+
if not extrapolate: # pragma: no cover
31283128
if func.min < self.x_initial or func.max > self.x_final:
31293129
raise ValueError(
31303130
f"Input Function image {func.min, func.max} must be within "
@@ -3197,7 +3197,7 @@ def savetxt(
31973197

31983198
# create the datapoints
31993199
if callable(self.source):
3200-
if lower is None or upper is None or samples is None:
3200+
if lower is None or upper is None or samples is None: # pragma: no cover
32013201
raise ValueError(
32023202
"If the source is a callable, lower, upper and samples"
32033203
+ " must be provided."
@@ -3323,6 +3323,7 @@ def __validate_inputs(self, inputs):
33233323
if isinstance(inputs, (list, tuple)):
33243324
if len(inputs) == 1:
33253325
return inputs
3326+
# pragma: no cover
33263327
raise ValueError(
33273328
"Inputs must be a string or a list of strings with "
33283329
"the length of the domain dimension."
@@ -3335,6 +3336,7 @@ def __validate_inputs(self, inputs):
33353336
isinstance(i, str) for i in inputs
33363337
):
33373338
return inputs
3339+
# pragma: no cover
33383340
raise ValueError(
33393341
"Inputs must be a list of strings with "
33403342
"the length of the domain dimension."

rocketpy/rocket/aero_surface/nose_cone.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def bluffness(self, value):
317317
raise ValueError(
318318
"Parameter 'bluffness' must be None or 0 when using a nose cone kind 'powerseries'."
319319
)
320-
if value is not None and not (0 <= value <= 1): # pragma: no cover
320+
if value is not None and not 0 <= value <= 1: # pragma: no cover
321321
raise ValueError(
322322
f"Bluffness ratio of {value} is out of range. "
323323
"It must be between 0 and 1."

rocketpy/sensitivity/sensitivity_model.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,6 @@ def set_target_variables_nominal(self, target_variables_nominal_value):
140140
self.target_variables_info[target_variable]["nominal_value"] = (
141141
target_variables_nominal_value[i]
142142
)
143-
for i, target_variable in enumerate(self.target_variables_names):
144-
self.target_variables_info[target_variable]["nominal_value"] = (
145-
target_variables_nominal_value[i]
146-
)
147143

148144
self._nominal_target_passed = True
149145

@@ -356,12 +352,12 @@ def __check_requirements(self):
356352
version = ">=0" if not version else version
357353
try:
358354
check_requirement_version(module_name, version)
359-
except (ValueError, ImportError) as e:
355+
except (ValueError, ImportError) as e: # pragma: no cover
360356
has_error = True
361357
print(
362358
f"The following error occurred while importing {module_name}: {e}"
363359
)
364-
if has_error:
360+
if has_error: # pragma: no cover
365361
print(
366362
"Given the above errors, some methods may not work. Please run "
367363
+ "'pip install rocketpy[sensitivity]' to install extra requirements."

rocketpy/simulation/flight.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
615615
self.env = environment
616616
self.rocket = rocket
617617
self.rail_length = rail_length
618-
if self.rail_length <= 0:
618+
if self.rail_length <= 0: # pragma: no cover
619619
raise ValueError("Rail length must be a positive value.")
620620
self.parachutes = self.rocket.parachutes[:]
621621
self.inclination = inclination
@@ -951,7 +951,7 @@ def __simulate(self, verbose):
951951
for t_root in t_roots
952952
if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1
953953
]
954-
if len(valid_t_root) > 1:
954+
if len(valid_t_root) > 1: # pragma: no cover
955955
raise ValueError(
956956
"Multiple roots found when solving for impact time."
957957
)
@@ -1226,7 +1226,7 @@ def __init_controllers(self):
12261226
self._controllers = self.rocket._controllers[:]
12271227
self.sensors = self.rocket.sensors.get_components()
12281228
if self._controllers or self.sensors:
1229-
if self.time_overshoot:
1229+
if self.time_overshoot: # pragma: no cover
12301230
self.time_overshoot = False
12311231
warnings.warn(
12321232
"time_overshoot has been set to False due to the presence "
@@ -1266,7 +1266,7 @@ def __set_ode_solver(self, solver):
12661266
else:
12671267
try:
12681268
self._solver = ODE_SOLVER_MAP[solver]
1269-
except KeyError as e:
1269+
except KeyError as e: # pragma: no cover
12701270
raise ValueError(
12711271
f"Invalid ``ode_solver`` input: {solver}. "
12721272
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
@@ -1398,7 +1398,7 @@ def udot_rail1(self, t, u, post_processing=False):
13981398

13991399
return [vx, vy, vz, ax, ay, az, 0, 0, 0, 0, 0, 0, 0]
14001400

1401-
def udot_rail2(self, t, u, post_processing=False):
1401+
def udot_rail2(self, t, u, post_processing=False): # pragma: no cover
14021402
"""[Still not implemented] Calculates derivative of u state vector with
14031403
respect to time when rocket is flying in 3 DOF motion in the rail.
14041404

rocketpy/tools.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -978,9 +978,8 @@ def wrapper(*args, **kwargs):
978978
for i in range(max_attempts):
979979
try:
980980
return func(*args, **kwargs)
981-
except (
982-
Exception
983-
) as e: # pragma: no cover # pylint: disable=broad-except
981+
# pylint: disable=broad-except
982+
except Exception as e: # pragma: no cover
984983
if i == max_attempts - 1:
985984
raise e from None
986985
delay = min(delay * 2, max_delay)

tests/fixtures/surfaces/surface_fixtures.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import pytest
22

3-
from rocketpy import NoseCone, RailButtons, Tail, TrapezoidalFins
4-
from rocketpy.rocket.aero_surface.fins.free_form_fins import FreeFormFins
3+
from rocketpy.rocket.aero_surface import (
4+
EllipticalFins,
5+
FreeFormFins,
6+
NoseCone,
7+
RailButtons,
8+
Tail,
9+
TrapezoidalFins,
10+
)
511

612

713
@pytest.fixture
@@ -94,3 +100,16 @@ def calisto_rail_buttons():
94100
angular_position=45,
95101
name="Rail Buttons",
96102
)
103+
104+
105+
@pytest.fixture
106+
def elliptical_fin_set():
107+
return EllipticalFins(
108+
n=4,
109+
span=0.100,
110+
root_chord=0.120,
111+
rocket_radius=0.0635,
112+
cant_angle=0,
113+
airfoil=None,
114+
name="Test Elliptical Fins",
115+
)

tests/unit/test_aero_surfaces.py

+68
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

35
from rocketpy import NoseCone
@@ -71,3 +73,69 @@ def test_powerseries_nosecones_setters(power, invalid_power, new_power):
7173
expected_k = (2 * new_power) / ((2 * new_power) + 1)
7274

7375
assert pytest.approx(test_nosecone.k) == expected_k
76+
77+
78+
@patch("matplotlib.pyplot.show")
79+
def test_elliptical_fins_draw(
80+
mock_show, elliptical_fin_set
81+
): # pylint: disable=unused-argument
82+
assert elliptical_fin_set.plots.draw(filename=None) is None
83+
84+
85+
def test_nose_cone_info(calisto_nose_cone):
86+
assert calisto_nose_cone.info() is None
87+
88+
89+
@patch("matplotlib.pyplot.show")
90+
def test_nose_cone_draw(
91+
mock_show, calisto_nose_cone
92+
): # pylint: disable=unused-argument
93+
assert calisto_nose_cone.draw(filename=None) is None
94+
95+
96+
def test_trapezoidal_fins_info(calisto_trapezoidal_fins):
97+
assert calisto_trapezoidal_fins.info() is None
98+
99+
100+
def test_trapezoidal_fins_tip_chord_setter(calisto_trapezoidal_fins):
101+
calisto_trapezoidal_fins.tip_chord = 0.1
102+
assert calisto_trapezoidal_fins.tip_chord == 0.1
103+
104+
105+
def test_trapezoidal_fins_root_chord_setter(calisto_trapezoidal_fins):
106+
calisto_trapezoidal_fins.root_chord = 0.1
107+
assert calisto_trapezoidal_fins.root_chord == 0.1
108+
109+
110+
def test_trapezoidal_fins_sweep_angle_setter(calisto_trapezoidal_fins):
111+
calisto_trapezoidal_fins.sweep_angle = 0.1
112+
assert calisto_trapezoidal_fins.sweep_angle == 0.1
113+
114+
115+
def test_trapezoidal_fins_sweep_length_setter(calisto_trapezoidal_fins):
116+
calisto_trapezoidal_fins.sweep_length = 0.1
117+
assert calisto_trapezoidal_fins.sweep_length == 0.1
118+
119+
120+
def test_tail_info(calisto_tail):
121+
assert calisto_tail.info() is None
122+
123+
124+
def test_tail_length_setter(calisto_tail):
125+
calisto_tail.length = 0.1
126+
assert calisto_tail.length == 0.1
127+
128+
129+
def test_tail_rocket_radius_setter(calisto_tail):
130+
calisto_tail.rocket_radius = 0.1
131+
assert calisto_tail.rocket_radius == 0.1
132+
133+
134+
def test_tail_bottom_radius_setter(calisto_tail):
135+
calisto_tail.bottom_radius = 0.1
136+
assert calisto_tail.bottom_radius == 0.1
137+
138+
139+
def test_tail_top_radius_setter(calisto_tail):
140+
calisto_tail.top_radius = 0.1
141+
assert calisto_tail.top_radius == 0.1

tests/unit/test_flight_time_nodes.py

+10
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,13 @@ def test_time_node_lt(flight_calisto):
9999
node2 = flight_calisto.TimeNodes.TimeNode(2.0, [], [], [])
100100
assert node1 < node2
101101
assert not node2 < node1
102+
103+
104+
def test_time_node_repr(flight_calisto):
105+
node = flight_calisto.TimeNodes.TimeNode(1.0, [], [], [])
106+
assert isinstance(repr(node), str)
107+
108+
109+
def test_time_nodes_repr(flight_calisto):
110+
time_nodes = flight_calisto.TimeNodes()
111+
assert isinstance(repr(time_nodes), str)

tests/unit/test_function.py

+104
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,107 @@ def test_low_pass_filter(alpha):
787787
f"The filtered value at index {i} is not the expected value. "
788788
f"Expected: {expected}, Actual: {filtered_func.source[i][1]}"
789789
)
790+
791+
792+
def test_average_function_ndarray():
793+
794+
dummy_function = Function(
795+
source=[
796+
[0, 0],
797+
[1, 1],
798+
[2, 0],
799+
[3, 1],
800+
[4, 0],
801+
[5, 1],
802+
[6, 0],
803+
[7, 1],
804+
[8, 0],
805+
[9, 1],
806+
],
807+
inputs=["x"],
808+
outputs=["y"],
809+
)
810+
avg_function = dummy_function.average_function()
811+
812+
assert isinstance(avg_function, Function)
813+
assert np.isclose(avg_function(0), 0)
814+
assert np.isclose(avg_function(9), 0.5)
815+
816+
817+
def test_average_function_callable():
818+
819+
dummy_function = Function(lambda x: 2)
820+
avg_function = dummy_function.average_function(lower=0)
821+
822+
assert isinstance(avg_function, Function)
823+
assert np.isclose(avg_function(1), 2)
824+
assert np.isclose(avg_function(9), 2)
825+
826+
827+
@pytest.mark.parametrize(
828+
"lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive",
829+
[
830+
(0, 10, 100, 1, 0.5, True, True),
831+
(0, 10, 100, 1, 0.5, True, False),
832+
(0, 10, 100, 1, 0.5, False, True),
833+
(0, 10, 100, 1, 0.5, False, False),
834+
(0, 20, 200, 2, 1, True, True),
835+
],
836+
)
837+
def test_short_time_fft(
838+
lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive
839+
):
840+
"""Test the short_time_fft method of the Function class.
841+
842+
Parameters
843+
----------
844+
lower : float
845+
Lower bound of the time range.
846+
upper : float
847+
Upper bound of the time range.
848+
sampling_frequency : float
849+
Sampling frequency at which to perform the Fourier transform.
850+
window_size : float
851+
Size of the window for the STFT, in seconds.
852+
step_size : float
853+
Step size for the window, in seconds.
854+
remove_dc : bool
855+
If True, the DC component is removed from each window before
856+
computing the Fourier transform.
857+
only_positive: bool
858+
If True, only the positive frequencies are returned.
859+
"""
860+
# Generate a test signal
861+
t = np.linspace(lower, upper, int((upper - lower) * sampling_frequency))
862+
signal = np.sin(2 * np.pi * 5 * t) # 5 Hz sine wave
863+
func = Function(np.column_stack((t, signal)))
864+
865+
# Perform STFT
866+
stft_results = func.short_time_fft(
867+
lower=lower,
868+
upper=upper,
869+
sampling_frequency=sampling_frequency,
870+
window_size=window_size,
871+
step_size=step_size,
872+
remove_dc=remove_dc,
873+
only_positive=only_positive,
874+
)
875+
876+
# Check the results
877+
assert isinstance(stft_results, list)
878+
assert all(isinstance(f, Function) for f in stft_results)
879+
880+
for f in stft_results:
881+
assert f.get_inputs() == ["Frequency (Hz)"]
882+
assert f.get_outputs() == ["Amplitude"]
883+
assert f.get_interpolation_method() == "linear"
884+
assert f.get_extrapolation_method() == "zero"
885+
886+
frequencies = f.source[:, 0]
887+
# amplitudes = f.source[:, 1]
888+
889+
if only_positive:
890+
assert np.all(frequencies >= 0)
891+
else:
892+
assert np.all(frequencies >= -sampling_frequency / 2)
893+
assert np.all(frequencies <= sampling_frequency / 2)

0 commit comments

Comments
 (0)