@@ -955,3 +955,126 @@ def test_metadata(self, asset):
955
955
)
956
956
assert decoder .metadata .sample_rate == asset .sample_rate
957
957
assert decoder .metadata .num_channels == asset .num_channels
958
+
959
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_error(self, asset):
    """Check that invalid time ranges are rejected with a ``ValueError``.

    Three requests are tried: a negative start, a very large start (9999),
    and a start that is greater than the stop. Each one must raise a
    ``ValueError`` whose message matches "Invalid start seconds".
    """
    decoder = AudioDecoder(asset.path)

    invalid_ranges = (
        {"start_seconds": -1300},
        {"start_seconds": 9999},
        {"start_seconds": 3, "stop_seconds": 2},
    )
    for range_kwargs in invalid_ranges:
        with pytest.raises(ValueError, match="Invalid start seconds"):
            decoder.get_samples_played_in_range(**range_kwargs)
971
+
972
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
def test_get_all_samples(self, asset, stop_seconds):
    """Decode from 0 with several stop values and compare to the reference frames.

    ``stop_seconds`` is ``None``, a very large number, or the ``"duration"``
    placeholder, which is replaced by ``asset.duration_seconds``. In every
    case the decoded data is compared against the reference frames from
    index 0 up to the frame index found at the asset's duration.
    """
    # "duration" is only a placeholder: the actual value depends on the asset.
    requested_stop = (
        asset.duration_seconds if stop_seconds == "duration" else stop_seconds
    )

    decoder = AudioDecoder(asset.path)
    decoded = decoder.get_samples_played_in_range(
        start_seconds=0, stop_seconds=requested_stop
    )

    last_frame_index = asset.get_frame_index(pts_seconds=asset.duration_seconds)
    all_frames = asset.get_frame_data_by_range(start=0, stop=last_frame_index + 1)

    torch.testing.assert_close(decoded.data, all_frames)
    assert decoded.sample_rate == asset.sample_rate

    # TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
    if asset is NASA_AUDIO_MP3:
        expected_pts = 0.072
    else:
        expected_pts = asset.get_frame_info(idx=0).pts_seconds
    assert decoded.pts_seconds == expected_pts
998
+
999
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_at_frame_boundaries(self, asset):
    """Decode a range whose start and stop are both frame pts values.

    The range runs from the pts of frame 10 to the pts of frame 40. The
    decoded samples are compared against the reference data obtained with
    ``start=10, stop=40``, and the returned pts must equal the requested
    start.
    """
    decoder = AudioDecoder(asset.path)

    start_frame_index, stop_frame_index = 10, 40
    start_seconds = asset.get_frame_info(start_frame_index).pts_seconds
    stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds

    samples = decoder.get_samples_played_in_range(
        start_seconds=start_seconds, stop_seconds=stop_seconds
    )

    reference_frames = asset.get_frame_data_by_range(
        start=start_frame_index, stop=stop_frame_index
    )

    assert samples.pts_seconds == start_seconds

    num_samples = samples.data.shape[1]
    assert num_samples == reference_frames.shape[1]

    # The expected count is the result of float arithmetic, so comparing an
    # integer to it with exact equality can fail on rounding noise alone.
    # The default tolerance of pytest.approx (relative 1e-6) is far below
    # one sample here, so a genuinely wrong sample count is still caught.
    expected_num_samples = (
        stop_seconds - start_seconds
    ) * decoder.metadata.sample_rate
    assert num_samples == pytest.approx(expected_num_samples)

    torch.testing.assert_close(samples.data, reference_frames)
    assert samples.sample_rate == asset.sample_rate
1024
+
1025
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_not_at_frame_boundaries(self, asset):
    """Decode a range whose start and stop both fall in the middle of a frame.

    Start and stop are each placed half a frame duration after the pts of
    frames 10 and 40 respectively. The decoded output must start exactly at
    the requested start, and must contain fewer samples than the reference
    data obtained with ``start=10, stop=41``.
    """
    decoder = AudioDecoder(asset.path)

    start_frame_index, stop_frame_index = 10, 40
    start_frame_info = asset.get_frame_info(start_frame_index)
    stop_frame_info = asset.get_frame_info(stop_frame_index)
    start_seconds = start_frame_info.pts_seconds + (
        start_frame_info.duration_seconds / 2
    )
    stop_seconds = stop_frame_info.pts_seconds + (
        stop_frame_info.duration_seconds / 2
    )

    samples = decoder.get_samples_played_in_range(
        start_seconds=start_seconds, stop_seconds=stop_seconds
    )

    reference_frames = asset.get_frame_data_by_range(
        start=start_frame_index, stop=stop_frame_index + 1
    )

    assert samples.pts_seconds == start_seconds

    num_samples = samples.data.shape[1]
    assert num_samples < reference_frames.shape[1]

    # The expected count goes through several float operations (half
    # durations, a subtraction and a multiplication), so exact equality
    # against an integer is fragile. The default tolerance of pytest.approx
    # (relative 1e-6) is far below one sample, so real miscounts still fail.
    expected_num_samples = (
        stop_seconds - start_seconds
    ) * decoder.metadata.sample_rate
    assert num_samples == pytest.approx(expected_num_samples)

    assert samples.sample_rate == asset.sample_rate
1053
+
1054
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_start_equals_stop(self, asset):
    """Check that a range with identical start and stop yields data of shape (0, 0)."""
    same_instant = 3
    empty_samples = AudioDecoder(asset.path).get_samples_played_in_range(
        start_seconds=same_instant, stop_seconds=same_instant
    )
    assert empty_samples.data.shape == (0, 0)
1059
+
1060
def test_frame_start_is_not_zero(self):
    """Check that a start earlier than the first frame's pts truncates nothing.

    The decoded samples for the range [0.05, pts of frame 10) must match the
    reference data obtained with ``start=0, stop=10``.
    """
    # The first frame of NASA_AUDIO_MP3 does not sit at 0 but at 0.072 [1],
    # so a request starting at 0.05 should not truncate anything.
    #
    # [1] Strictly speaking it is at 0.138125 rather than 0.072 (see
    # https://github.com/pytorch/torchcodec/issues/553), but that makes no
    # difference for what this test checks.
    asset = NASA_AUDIO_MP3
    last_frame_index = 10

    early_start = 0.05  # this is less than the first frame's pts
    stop_at = asset.get_frame_info(last_frame_index).pts_seconds

    decoded = AudioDecoder(asset.path).get_samples_played_in_range(
        start_seconds=early_start, stop_seconds=stop_at
    )

    expected_frames = asset.get_frame_data_by_range(start=0, stop=last_frame_index)
    torch.testing.assert_close(decoded.data, expected_frames)
0 commit comments