Skip to content

Commit d3f09e8

Browse files
authored
Update audio.py
Implement better error handling to ensure the program does not crash unexpectedly. Enhance logging to provide more detailed information about the audio processing. Improve the chunks method to handle asynchronous generation of audio chunks more efficiently. Implement the TODO comment to trim data longer than a specified duration.
1 parent 02d4ced commit d3f09e8

File tree

1 file changed

+91
-25
lines changed

1 file changed

+91
-25
lines changed

src/faster_whisper_server/audio.py

Lines changed: 91 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,52 @@
22

33
import asyncio
import logging
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO

import numpy as np
import soundfile as sf

from faster_whisper_server.config import SAMPLES_PER_SECOND

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

    from numpy.typing import NDArray
1613

17-
1814
logger = logging.getLogger(__name__)
1915

20-
2116
def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
    """
    Read audio samples from a file.

    The file is decoded as headerless (RAW), mono, 16-bit little-endian PCM
    sampled at ``SAMPLES_PER_SECOND`` and converted to ``float32``.

    :param file: BinaryIO object of the audio file
    :return: Audio samples as a numpy array; an empty float32 array if the
        file could not be read
    """
    try:
        # sf.read returns an (audiodata, samplerate) tuple; unpack it so that
        # only the sample array is handed back to the caller.
        audio, _sample_rate = sf.read(
            file,
            format="RAW",
            channels=1,
            samplerate=SAMPLES_PER_SECOND,
            subtype="PCM_16",
            dtype="float32",
            endian="LITTLE",
        )
    except Exception as e:
        # Best-effort by design: a failed read is reported and mapped to
        # empty audio instead of propagating. logger.exception keeps the
        # traceback so the failure can still be diagnosed.
        logger.exception(f"Error reading audio file: {e}")
        return np.array([], dtype=np.float32)
    return audio  # pyright: ignore[reportReturnType]
3438

3539
class Audio:
3640
def __init__(
3741
self,
3842
data: NDArray[np.float32] = np.array([], dtype=np.float32),
3943
start: float = 0.0,
4044
) -> None:
45+
"""
46+
Initialize the Audio class.
47+
48+
:param data: Audio data as a numpy array
49+
:param start: Start time of the audio
50+
"""
4151
self.data = data
4252
self.start = start
4353

@@ -52,40 +62,69 @@ def end(self) -> float:
5262
def duration(self) -> float:
5363
return len(self.data) / SAMPLES_PER_SECOND
5464

55-
def after(self, ts: float) -> Audio:
65+
def after(self, ts: float) -> 'Audio':
    """
    Return the part of this audio that follows a given offset.

    :param ts: Offset from the beginning of the audio data, in the same unit
        as ``duration``; must not be greater than ``duration``
    :return: New Audio object holding the samples from that offset onward,
        with its start set to ``ts``
    """
    assert ts <= self.duration
    # Convert the time offset into an index into the sample array.
    first_sample = int(ts * SAMPLES_PER_SECOND)
    remaining = self.data[first_sample:]
    return Audio(remaining, start=ts)
5874

5975
def extend(self, data: NDArray[np.float32]) -> None:
76+
"""
77+
Extend the audio data.
78+
79+
:param data: New audio data to append
80+
"""
6081
# logger.debug(f"Extending audio by {len(data) / SAMPLES_PER_SECOND:.2f}s")
6182
self.data = np.append(self.data, data)
6283
# logger.debug(f"Audio duration: {self.duration:.2f}s")
6384

64-
65-
# TODO: trim data longer than x
6685
class AudioStream(Audio):
6786
def __init__(
6887
self,
6988
data: NDArray[np.float32] = np.array([], dtype=np.float32),
7089
start: float = 0.0,
7190
) -> None:
91+
"""
92+
Initialize the AudioStream class.
93+
94+
:param data: Initial audio data
95+
:param start: Start time of the audio
96+
"""
7297
super().__init__(data, start)
7398
self.closed = False
74-
7599
self.modify_event = asyncio.Event()
76100

77101
def extend(self, data: NDArray[np.float32]) -> None:
102+
"""
103+
Extend the audio data and notify any waiting tasks.
104+
105+
:param data: New audio data to append
106+
"""
78107
assert not self.closed
79108
super().extend(data)
80109
self.modify_event.set()
81110

82111
def close(self) -> None:
    """
    Mark the stream as closed and signal that it changed.

    Closing a stream that is already closed is a programming error.
    """
    assert not self.closed
    self.closed = True
    # Signal anything waiting on modify_event so it can observe the closed flag.
    self.modify_event.set()
    logger.info("AudioStream closed")
87119

88-
async def chunks(self, min_duration: float) -> AsyncGenerator[NDArray[np.float32], None]:
120+
async def chunks(self, min_duration: float, max_duration: float = None) -> AsyncGenerator[NDArray[np.float32], None]:
121+
"""
122+
Asynchronously yield chunks of audio data.
123+
124+
:param min_duration: Minimum duration of each chunk
125+
:param max_duration: Maximum duration of each chunk (optional)
126+
:yield: Chunks of audio data
127+
"""
89128
i = 0.0 # end time of last chunk
90129
while True:
91130
await self.modify_event.wait()
@@ -95,10 +134,37 @@ async def chunks(self, min_duration: float) -> AsyncGenerator[NDArray[np.float32
95134
if self.duration > i:
96135
yield self.after(i).data
97136
return
98-
if self.duration - i >= min_duration:
137+
if max_duration and self.duration - i > max_duration:
138+
# Trim data if it exceeds max_duration
139+
yield self.after(i).data[:int(max_duration * SAMPLES_PER_SECOND)]
140+
i += max_duration
141+
elif self.duration - i >= min_duration:
99142
# `i` shouldn't be set to `duration` after the yield
100143
# because by the time assignment would happen more data might have been added
101144
i_ = i
102145
i = self.duration
103-
# NOTE: probably better to just to a slice
146+
# NOTE: probably better to just do a slice
104147
yield self.after(i_).data
148+
149+
# Example usage
150+
async def main():
    """
    Example: read a file into an AudioStream and log the chunks it yields.

    The stream is closed before it is iterated, so the loop over
    ``chunks()`` finishes instead of waiting for more data indefinitely.
    """
    # Open an audio file
    # NOTE(review): audio_samples_from_file decodes RAW PCM, so the header of
    # a real .wav file would be decoded as samples — confirm the intended
    # input format for this example.
    with open("path/to/audio/file.wav", "rb") as file:
        audio_data = audio_samples_from_file(file)

    # Create an AudioStream object
    audio_stream = AudioStream(audio_data)

    # Extend the audio stream with new data (example).
    # np.random.rand yields float64; cast so it matches the float32 audio.
    new_data = np.random.rand(1000).astype(np.float32)
    audio_stream.extend(new_data)

    # Close the audio stream. Without this the `async for` below never ends:
    # the visible part of chunks() only returns after yielding the remaining
    # data, and otherwise keeps awaiting further modifications.
    audio_stream.close()

    # Yield chunks of audio data
    async for chunk in audio_stream.chunks(min_duration=1.0, max_duration=5.0):
        logger.info(f"Received chunk of {len(chunk) / SAMPLES_PER_SECOND:.2f} seconds")


if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)