Commit a14811c

Fix virtual function issue with CTC decoder (#3230)
Summary: Currently, creating a CTCDecoder object by passing a language model to the `lm` argument, without also assigning it to a variable elsewhere, causes `RuntimeError: Tried to call pure virtual function "LM::start"`.

According to discussions on pybind11 (pybind/pybind11#4013 and pybind/pybind11#2839), this happens because the Python object has been garbage-collected by the time it is used by code implemented in C++. The C++ code attempts to call methods defined in Python, which override the base class's pure virtual functions, but the object that provides these overrides has already been deleted by the garbage collector, because no reference to the original object is retained.

This commit fixes the problem by simply assigning the given `lm` object as an attribute of the CTCDecoder object.

Addresses #3218
Pull Request resolved: #3230
Reviewed By: hwangjeff
Differential Revision: D44642989
Pulled By: mthrok
fbshipit-source-id: a90af828c7c576bc0eb505164327365ebaadc471
1 parent 3b40834 commit a14811c
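
The lifetime problem described in the commit message can be illustrated without torchaudio or pybind11. The following is only a rough analogy in pure Python: `NativeDecoder`, `BrokenDecoder`, `FixedDecoder`, and `ZeroLM` are hypothetical names, and a weak reference stands in for the C++ side that does not keep the Python object alive. It is not how pybind11 is implemented, and it relies on CPython's reference counting to free the temporary promptly.

```python
import gc
import weakref


class ZeroLM:
    """Hypothetical stand-in for a Python-defined language model."""

    def start(self):
        return 0


class NativeDecoder:
    """Stand-in for the C++ decoder: it does NOT keep the lm alive."""

    def __init__(self, lm):
        # A weak reference models "the native side holds no Python reference".
        self._lm_ref = weakref.ref(lm)

    def decode(self):
        lm = self._lm_ref()
        if lm is None:
            # Loosely analogous to the pure virtual call error in the report.
            raise RuntimeError("lm was garbage collected before use")
        return lm.start()


class BrokenDecoder:
    """Wrapper before the fix: nothing holds a strong reference to lm."""

    def __init__(self, lm):
        self.decoder = NativeDecoder(lm)


class FixedDecoder:
    """Wrapper after the fix: lm lives as long as the wrapper does."""

    def __init__(self, lm):
        self.decoder = NativeDecoder(lm)
        self.lm = lm  # same idea as `self.lm = lm` in this commit


if __name__ == "__main__":
    fixed = FixedDecoder(ZeroLM())
    gc.collect()
    print(fixed.decoder.decode())

    broken = BrokenDecoder(ZeroLM())
    gc.collect()
    try:
        broken.decoder.decode()
    except RuntimeError as err:
        print("error:", err)
```

In this sketch, passing a temporary `ZeroLM()` to `BrokenDecoder` fails on the first call, while `FixedDecoder` keeps working because the wrapper itself owns a strong reference.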

2 files changed

+22
-0
lines changed


test/torchaudio_unittest/models/decoder/ctc_decoder_test.py

+16
@@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens):
 
         expected_tokens = ["|", "f", "|", "o", "a"]
         self.assertEqual(tokens, expected_tokens)
+
+    def test_lm_lifecycle(self):
+        """Passing lm without assigning it to a variable won't cause runtime error
+
+        https://github.com/pytorch/audio/issues/3218
+        """
+        from torchaudio.models.decoder import ctc_decoder
+
+        from .ctc_decoder_utils import CustomZeroLM
+
+        decoder = ctc_decoder(
+            lexicon=get_asset_path("decoder/lexicon.txt"),
+            tokens=get_asset_path("decoder/tokens.txt"),
+            lm=CustomZeroLM(),
+        )
+        decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))

torchaudio/models/decoder/_ctc_decoder.py

+6
@@ -269,6 +269,12 @@ def __init__(
             )
         else:
             self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
+        # https://github.com/pytorch/audio/issues/3218
+        # If lm is passed like rvalue reference, the lm object gets garbage collected,
+        # and later call to the lm fails.
+        # This ensures that lm object is not deleted as long as the decoder is alive.
+        # https://github.com/pybind/pybind11/discussions/4013
+        self.lm = lm
 
     def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
         idxs = (g[0] for g in it.groupby(idxs))
