@@ -787,3 +787,107 @@ def test_low_pass_filter(alpha):
787
787
f"The filtered value at index { i } is not the expected value. "
788
788
f"Expected: { expected } , Actual: { filtered_func .source [i ][1 ]} "
789
789
)
790
+
791
+
792
+ def test_average_function_ndarray ():
793
+
794
+ dummy_function = Function (
795
+ source = [
796
+ [0 , 0 ],
797
+ [1 , 1 ],
798
+ [2 , 0 ],
799
+ [3 , 1 ],
800
+ [4 , 0 ],
801
+ [5 , 1 ],
802
+ [6 , 0 ],
803
+ [7 , 1 ],
804
+ [8 , 0 ],
805
+ [9 , 1 ],
806
+ ],
807
+ inputs = ["x" ],
808
+ outputs = ["y" ],
809
+ )
810
+ avg_function = dummy_function .average_function ()
811
+
812
+ assert isinstance (avg_function , Function )
813
+ assert np .isclose (avg_function (0 ), 0 )
814
+ assert np .isclose (avg_function (9 ), 0.5 )
815
+
816
+
817
+ def test_average_function_callable ():
818
+
819
+ dummy_function = Function (lambda x : 2 )
820
+ avg_function = dummy_function .average_function (lower = 0 )
821
+
822
+ assert isinstance (avg_function , Function )
823
+ assert np .isclose (avg_function (1 ), 2 )
824
+ assert np .isclose (avg_function (9 ), 2 )
825
+
826
+
827
+ @pytest .mark .parametrize (
828
+ "lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive" ,
829
+ [
830
+ (0 , 10 , 100 , 1 , 0.5 , True , True ),
831
+ (0 , 10 , 100 , 1 , 0.5 , True , False ),
832
+ (0 , 10 , 100 , 1 , 0.5 , False , True ),
833
+ (0 , 10 , 100 , 1 , 0.5 , False , False ),
834
+ (0 , 20 , 200 , 2 , 1 , True , True ),
835
+ ],
836
+ )
837
+ def test_short_time_fft (
838
+ lower , upper , sampling_frequency , window_size , step_size , remove_dc , only_positive
839
+ ):
840
+ """Test the short_time_fft method of the Function class.
841
+
842
+ Parameters
843
+ ----------
844
+ lower : float
845
+ Lower bound of the time range.
846
+ upper : float
847
+ Upper bound of the time range.
848
+ sampling_frequency : float
849
+ Sampling frequency at which to perform the Fourier transform.
850
+ window_size : float
851
+ Size of the window for the STFT, in seconds.
852
+ step_size : float
853
+ Step size for the window, in seconds.
854
+ remove_dc : bool
855
+ If True, the DC component is removed from each window before
856
+ computing the Fourier transform.
857
+ only_positive: bool
858
+ If True, only the positive frequencies are returned.
859
+ """
860
+ # Generate a test signal
861
+ t = np .linspace (lower , upper , int ((upper - lower ) * sampling_frequency ))
862
+ signal = np .sin (2 * np .pi * 5 * t ) # 5 Hz sine wave
863
+ func = Function (np .column_stack ((t , signal )))
864
+
865
+ # Perform STFT
866
+ stft_results = func .short_time_fft (
867
+ lower = lower ,
868
+ upper = upper ,
869
+ sampling_frequency = sampling_frequency ,
870
+ window_size = window_size ,
871
+ step_size = step_size ,
872
+ remove_dc = remove_dc ,
873
+ only_positive = only_positive ,
874
+ )
875
+
876
+ # Check the results
877
+ assert isinstance (stft_results , list )
878
+ assert all (isinstance (f , Function ) for f in stft_results )
879
+
880
+ for f in stft_results :
881
+ assert f .get_inputs () == ["Frequency (Hz)" ]
882
+ assert f .get_outputs () == ["Amplitude" ]
883
+ assert f .get_interpolation_method () == "linear"
884
+ assert f .get_extrapolation_method () == "zero"
885
+
886
+ frequencies = f .source [:, 0 ]
887
+ # amplitudes = f.source[:, 1]
888
+
889
+ if only_positive :
890
+ assert np .all (frequencies >= 0 )
891
+ else :
892
+ assert np .all (frequencies >= - sampling_frequency / 2 )
893
+ assert np .all (frequencies <= sampling_frequency / 2 )
0 commit comments