Skip to content

Commit 96db0c5

Browse files
aderbedr (Amanda Der Bedrosian) and Amanda Der Bedrosian authored
go : add Encoder Begin Callback (#2900)
Add an EncoderBeginCallback parameter to the Context's Process method. This optional callback function returns false if computation should be aborted. Co-authored-by: Amanda Der Bedrosian <[email protected]>
1 parent d2aaffd commit 96db0c5

File tree

5 files changed

+28
-21
lines changed

5 files changed

+28
-21
lines changed

bindings/go/README.md

+1-1
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ func main() {
3131
if err != nil {
3232
panic(err)
3333
}
34-
if err := context.Process(samples, nil, nil); err != nil {
34+
if err := context.Process(samples, nil, nil, nil); err != nil {
3535
return err
3636
}
3737

bindings/go/examples/go-whisper/process.go

+1-1
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
6767
// Process the data
6868
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
6969
context.ResetTimings()
70-
if err := context.Process(data, cb, nil); err != nil {
70+
if err := context.Process(data, nil, cb, nil); err != nil {
7171
return err
7272
}
7373

bindings/go/pkg/whisper/context.go

+19-16
Original file line number | Diff line number | Diff line change
@@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
189189
// Process new sample data and return any errors
190190
func (context *context) Process(
191191
data []float32,
192+
callEncoderBegin EncoderBeginCallback,
192193
callNewSegment SegmentCallback,
193194
callProgress ProgressCallback,
194195
) error {
@@ -203,30 +204,32 @@ func (context *context) Process(
203204
// We don't do parallel processing at the moment
204205
processors := 0
205206
if processors > 1 {
206-
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
207+
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
208+
func(new int) {
209+
if callNewSegment != nil {
210+
num_segments := context.model.ctx.Whisper_full_n_segments()
211+
s0 := num_segments - new
212+
for i := s0; i < num_segments; i++ {
213+
callNewSegment(toSegment(context.model.ctx, i))
214+
}
215+
}
216+
}); err != nil {
217+
return err
218+
}
219+
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
220+
func(new int) {
207221
if callNewSegment != nil {
208222
num_segments := context.model.ctx.Whisper_full_n_segments()
209223
s0 := num_segments - new
210224
for i := s0; i < num_segments; i++ {
211225
callNewSegment(toSegment(context.model.ctx, i))
212226
}
213227
}
214-
}); err != nil {
215-
return err
216-
}
217-
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
218-
if callNewSegment != nil {
219-
num_segments := context.model.ctx.Whisper_full_n_segments()
220-
s0 := num_segments - new
221-
for i := s0; i < num_segments; i++ {
222-
callNewSegment(toSegment(context.model.ctx, i))
228+
}, func(progress int) {
229+
if callProgress != nil {
230+
callProgress(progress)
223231
}
224-
}
225-
}, func(progress int) {
226-
if callProgress != nil {
227-
callProgress(progress)
228-
}
229-
}); err != nil {
232+
}); err != nil {
230233
return err
231234
}
232235

bindings/go/pkg/whisper/context_test.go

+1-1
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,6 @@ func TestProcess(t *testing.T) {
8888
context, err := model.NewContext()
8989
assert.NoError(err)
9090

91-
err = context.Process(data, nil, nil)
91+
err = context.Process(data, nil, nil, nil)
9292
assert.NoError(err)
9393
}

bindings/go/pkg/whisper/interface.go

+6-2
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
1616
// processing. It is called during the Process function
1717
type ProgressCallback func(int)
1818

19+
// EncoderBeginCallback is the callback function for checking if we want to
20+
// continue processing. It is called during the Process function
21+
type EncoderBeginCallback func() bool
22+
1923
// Model is the interface to a whisper model. Create a new model with the
2024
// function whisper.New(string)
2125
type Model interface {
@@ -31,7 +35,7 @@ type Model interface {
3135
Languages() []string
3236
}
3337

34-
// Context is the speach recognition context.
38+
// Context is the speech recognition context.
3539
type Context interface {
3640
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
3741
SetTranslate(bool) // Set translate flag
@@ -58,7 +62,7 @@ type Context interface {
5862
// Process mono audio data and return any errors.
5963
// If defined, newly generated segments are passed to the
6064
// callback function during processing.
61-
Process([]float32, SegmentCallback, ProgressCallback) error
65+
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
6266

6367
// After process is called, return segments until the end of the stream
6468
// is reached, when io.EOF is returned.

0 commit comments

Comments
 (0)