From 55f8b92081dc5c7b98ef3a2a52de15966be663dc Mon Sep 17 00:00:00 2001
From: Brandon Lee
Date: Thu, 15 Apr 2021 21:58:34 -0400
Subject: [PATCH] fixing phone class (#76)

* fixing phone class

* adding documentation
---
 .pre-commit-config.yaml                      |   4 +-
 howl/context.py                              |  68 +++++-----
 howl/data/dataset/labeller.py                |  22 ++--
 howl/data/dataset/phone.py                   | 130 ++++++++++++++-----
 test/data/{ => dataset}/dataset_test.py      |  24 ++--
 test/data/dataset/phone_test.py              |  78 +++++++++++
 test/test_data/pronounciation_dictionary.txt |   4 +
 7 files changed, 245 insertions(+), 85 deletions(-)
 rename test/data/{ => dataset}/dataset_test.py (83%)
 create mode 100644 test/data/dataset/phone_test.py
 create mode 100644 test/test_data/pronounciation_dictionary.txt

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8f631409..da207e50 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,9 @@ repos:
     rev: 3.8.3
     hooks:
       - id: flake8
-        args: [--max-line-length=120]
+        args:
+          - --max-line-length=120
+          - --ignore=E203
   - repo: https://github.com/pre-commit/mirrors-pylint
     rev: v2.6.0
     hooks:
diff --git a/howl/context.py b/howl/context.py
index c24cd365..4224b508 100644
--- a/howl/context.py
+++ b/howl/context.py
@@ -1,12 +1,20 @@
 import logging
 from dataclasses import dataclass
+from pathlib import Path
 from typing import List
 
-from howl.data.dataset import (PhonePhrase, PhoneticFrameLabeler,
-                               PronunciationDictionary, WakeWordDataset,
-                               WordFrameLabeler)
-from howl.data.searcher import (LabelColoring, PhoneticTranscriptSearcher,
-                                WordTranscriptSearcher)
+from howl.data.dataset import (
+    PhonePhrase,
+    PhoneticFrameLabeler,
+    PronunciationDictionary,
+    WakeWordDataset,
+    WordFrameLabeler,
+)
+from howl.data.searcher import (
+    LabelColoring,
+    PhoneticTranscriptSearcher,
+    WordTranscriptSearcher,
+)
 from howl.data.tokenize import Vocab
 from howl.settings import SETTINGS
 
@@ -21,64 +29,58 @@ class WakewordDatasetContext:
 
 
 class InferenceContext:
-    def __init__(self,
-                 vocab: List[str],
-                 token_type: str = 'phone',
-                 pronounce_dict: PronunciationDictionary = None,
-                 use_blank: bool = False):
+    def __init__(self, vocab: List[str], token_type: str = "phone", use_blank: bool = False):
 
         self.coloring = None
         self.adjusted_vocab = []
         self.num_labels = 0
         self.token_type = token_type
+        self.pronounce_dict = None
 
         # break down each vocab into phonemes
-        if token_type == 'phone':
-            if pronounce_dict is None:
-                pronounce_dict = PronunciationDictionary.from_file(
-                    SETTINGS.training.phone_dictionary)
+        if token_type == "phone":
+            self.pronounce_dict = PronunciationDictionary.from_file(Path(SETTINGS.training.phone_dictionary))
 
             self.coloring = LabelColoring()
             for word in vocab:
-                phone_phrases = pronounce_dict.encode(word)
-                logging.info(
-                    f'Using phonemes {str(phone_phrases)} for word {word}')
-                self.add_vocab(x.text for x in phone_phrases)
+                # TODO: we currently use a single representation for simplicity
+                phone_phrase = self.pronounce_dict.encode(word)[0]
+                logging.info(f"Word {word: <10} has phonemes of {str(phone_phrase)}")
+                self.add_vocab(list(str(phone) for phone in phone_phrase.phones))
 
-        elif token_type == 'word':
+        elif token_type == "word":
             self.add_vocab(vocab)
 
         # initialize vocab set for the system
         self.negative_label = len(self.adjusted_vocab)
-        self.vocab = Vocab({word: idx for idx, word in enumerate(
-            self.adjusted_vocab)}, oov_token_id=self.negative_label)
+        self.vocab = Vocab(
+            {word: idx for idx, word in enumerate(self.adjusted_vocab)}, oov_token_id=self.negative_label
+        )
 
         # initialize labeler; make sure this is located before adding other labels
-        if token_type == 'phone':
-            phone_phrases = [PhonePhrase.from_string(
-                x) for x in self.adjusted_vocab]
-            self.labeler = PhoneticFrameLabeler(phone_phrases)
-        elif token_type == 'word':
+        if token_type == "phone":
+            phone_phrases = [PhonePhrase.from_string(x) for x in self.adjusted_vocab]
+            self.labeler = PhoneticFrameLabeler(self.pronounce_dict, phone_phrases)
+        elif token_type == "word":
             self.labeler = WordFrameLabeler(self.vocab)
 
         # add negative label
-        self.add_vocab(['[OOV]'])
+        self.add_vocab(["[OOV]"])
 
         # initialize TranscriptSearcher with the processed targets
-        if token_type == 'phone':
-            self.searcher = PhoneticTranscriptSearcher(
-                phone_phrases, self.coloring)
-        elif token_type == 'word':
+        if token_type == "phone":
+            self.searcher = PhoneticTranscriptSearcher(phone_phrases, self.coloring)
+        elif token_type == "word":
             self.searcher = WordTranscriptSearcher(self.vocab)
 
         # add extra label for blank if necessary
         self.blank_label = -1
         if use_blank:
             self.blank_label = len(self.adjusted_vocab)
-            self.add_vocab(['[BLANK]'])
+            self.add_vocab(["[BLANK]"])
 
         for idx, word in enumerate(self.adjusted_vocab):
-            logging.info(f'target {word:10} is assigned to label {idx}')
+            logging.info(f"target {word:10} is assigned to label {idx}")
 
     def add_vocab(self, vocabs: List[str]):
         for vocab in vocabs:
diff --git a/howl/data/dataset/labeller.py b/howl/data/dataset/labeller.py
index 97544eba..dff3845b 100644
--- a/howl/data/dataset/labeller.py
+++ b/howl/data/dataset/labeller.py
@@ -1,15 +1,11 @@
-
-from dataclasses import dataclass
 from typing import List
 
-from howl.data.dataset.phone import PhonePhrase
+from howl.data.dataset.phone import PhonePhrase, PronunciationDictionary
 from howl.data.tokenize import Vocab
 
 from .base import AudioClipMetadata, FrameLabelData
 
-__all__ = ['FrameLabeler',
-           'WordFrameLabeler',
-           'PhoneticFrameLabeler']
+__all__ = ["FrameLabeler", "WordFrameLabeler", "PhoneticFrameLabeler"]
 
 
 class FrameLabeler:
@@ -18,12 +14,17 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
 
 
 class PhoneticFrameLabeler(FrameLabeler):
-    def __init__(self, phrases: List[PhonePhrase]):
+    def __init__(self, pronounce_dict: PronunciationDictionary, phrases: List[PhonePhrase]):
+        self.pronounce_dict = pronounce_dict
         self.phrases = phrases
 
     def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
         frame_labels = dict()
+        start_timestamp = []
+        char_indices = []
+
         start = 0
+        # TODO: this should use the pronunciation instead of the transcription
         pp = PhonePhrase.from_string(metadata.transcription)
         for idx, phrase in enumerate(self.phrases):
             while True:
@@ -35,7 +36,8 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
                 start = pp.all_idx_to_transcript_idx(pp.audible_idx_to_all_idx(start))
                 frame_labels[metadata.end_timestamps[start + len(str(phrase)) - 1]] = idx
                 start += 1
-        return FrameLabelData(frame_labels)
+
+        return FrameLabelData(frame_labels, start_timestamp, char_indices)
 
 
 class WordFrameLabeler(FrameLabeler):
@@ -45,8 +47,8 @@ def __init__(self, vocab: Vocab, ceil_word_boundary: bool = False):
 
     def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
         frame_labels = dict()
-        char_indices = []
         start_timestamp = []
+        char_indices = []
         char_idx = 0
 
         for word in metadata.transcription.split():
@@ -59,7 +61,7 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
             end_timestamp = metadata.end_timestamps[char_idx + word_size - 1]
             frame_labels[end_timestamp] = label
             char_indices.append((label, list(range(char_idx, char_idx + word_size))))
-            start_timestamp.append((label, metadata.end_timestamps[char_idx-1] if char_idx > 0 else 0.0))
+            start_timestamp.append((label, metadata.end_timestamps[char_idx - 1] if char_idx > 0 else 0.0))
 
             char_idx += word_size + 1  # space
 
diff --git a/howl/data/dataset/phone.py b/howl/data/dataset/phone.py
index 67942818..208bd37e 100644
--- a/howl/data/dataset/phone.py
+++ b/howl/data/dataset/phone.py
@@ -1,18 +1,19 @@
 import enum
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import lru_cache
 from pathlib import Path
 from typing import List, Mapping
 
-__all__ = ['PhoneEnum',
-           'PronunciationDictionary',
-           'PhonePhrase']
+from howl.settings import SETTINGS
+
+__all__ = ["PhoneEnum", "PronunciationDictionary", "PhonePhrase"]
 
 
 class PhoneEnum(enum.Enum):
-    SILENCE = 'sil'
-    SILENCE_OPTIONAL = 'sp'
-    SPEECH_UNKNOWN = 'spn'
+    SILENCE = "sil"
+    SILENCE_OPTIONAL = "sp"
+    SPEECH_UNKNOWN = "spn"
 
 
 @dataclass
@@ -21,14 +22,16 @@ class Phone:
 
     def __post_init__(self):
         self.text = self.text.lower().strip()
-        self.is_speech = self.text not in (PhoneEnum.SILENCE.value,
-                                           PhoneEnum.SILENCE_OPTIONAL.value,
-                                           PhoneEnum.SPEECH_UNKNOWN.value)
+        self.is_speech = self.text not in (
+            PhoneEnum.SILENCE.value,
+            PhoneEnum.SILENCE_OPTIONAL.value,
+            PhoneEnum.SPEECH_UNKNOWN.value,
+        )
 
     def __str__(self):
         return self.text
 
-    def __eq__(self, other: 'Phone'):
+    def __eq__(self, other: "Phone"):
         return other.text == self.text
 
 
@@ -38,7 +41,7 @@ class PhonePhrase:
 
     def __post_init__(self):
         self.audible_phones = [x for x in self.phones if x.is_speech]
-        self.audible_transcript = ' '.join(x.text for x in self.audible_phones)
+        self.audible_transcript = " ".join(x.text for x in self.audible_phones)
         self.sil_indices = [idx for idx, x in enumerate(self.phones) if not x.is_speech]
 
     @property
@@ -50,44 +53,109 @@ def from_string(cls, string: str):
         return cls([Phone(x) for x in string.split()])
 
     def __str__(self):
-        return ' '.join(x.text for x in self.phones)
-
-    def all_idx_to_transcript_idx(self, index: int) -> int:
-        return sum(map(len, [phone.text for phone in self.phones[:index]])) + index
-
-    def audible_idx_to_all_idx(self, index: int) -> int:
+        return " ".join(x.text for x in self.phones)
+
+    def all_idx_to_transcript_idx(self, phone_idx: int) -> int:
+        """get the transcription index for the given phone_idx
+
+        pp = PhonePhrase.from_string("abc def ghi")
+        pp.all_idx_to_transcript_idx(0) # 3 - where the first phone (abc) finishes
+        pp.all_idx_to_transcript_idx(1) # 7 - where the second phone (def) finishes
+        pp.all_idx_to_transcript_idx(2) # 11 - where the third phone (ghi) finishes
+
+        Args:
+            phone_idx (int): target phone idx
+
+        Raises:
+            ValueError: if phone idx is out of bounds
+
+        Returns:
+            int: transcription idx where the phone at the given phone_idx finishes
+        """
+        if phone_idx >= len(self.phones):
+            raise ValueError(f"Given phone idx ({phone_idx}) is greater than the number of phones ({len(self.phones)})")
+        all_idx_without_space = sum(map(len, [phone.text for phone in self.phones[: phone_idx + 1]]))
+        return all_idx_without_space + phone_idx  # add phone_idx for spaces between phones
+
+    def audible_idx_to_all_idx(self, audible_idx: int) -> int:
+        """convert given audible index to phone index including all non-speech phones
+
+        pp = PhonePhrase.from_string("abc sil ghi")
+        pp.audible_idx_to_all_idx(0) # 0 - where the first audible phone (abc) is located in the whole phrase
+        pp.audible_idx_to_all_idx(1) # 2 - where the second audible phone (ghi) is located in the whole phrase
+
+        Args:
+            audible_idx (int): audible phone index to convert
+
+        Raises:
+            ValueError: if audible phone index is out of bounds
+
+        Returns:
+            int: the index of the audible phone in the whole phrase
+        """
+        if audible_idx >= len(self.audible_phones):
+            raise ValueError(
+                f"Given audible phone idx ({audible_idx}) is greater than "
+                f"the number of audible phones ({len(self.audible_phones)})"
+            )
         offset = 0
         for sil_idx in self.sil_indices:
-            if sil_idx <= index + offset:
+            if sil_idx <= audible_idx + offset:
                 offset += 1
-        return offset + index
+        return offset + audible_idx
+
+    def audible_index(self, query: "PhonePhrase", start: int = 0) -> int:
+        """find the starting audible index of the given phrase in the current phrase
+
+        pp = PhonePhrase.from_string("abc sil ghi")
+        ghi_pp = PhonePhrase.from_string("ghi")
+        pp.audible_index(ghi_pp, 0) # 1 - audible index of the query phone (ghi)
 
-    def audible_index(self, item: 'PhonePhrase', start: int = 0):
-        item_len = len(item.audible_phones)
+        Args:
+            query (PhonePhrase): phone phrase to be searched
+            start (int, optional): starting index in the whole phrase. Defaults to 0.
+
+        Raises:
+            ValueError: when the query phone phrase does not contain any phone
+            ValueError: when the query phone phrase is not found
+
+        Returns:
+            int: audible index of the query phone phrase
+        """
+        query_len = len(query.audible_phones)
+        if query_len == 0:
+            raise ValueError(f"query phrase has empty audible_phones: {query.audible_transcript}")
         self_len = len(self.audible_phones)
-        for idx in range(start, self_len - item_len + 1):
-            if all(x == y for x, y in zip(item.audible_phones, self.audible_phones[idx:idx + item_len])):
+        for idx in range(start, self_len - query_len + 1):
+            if all(x == y for x, y in zip(query.audible_phones, self.audible_phones[idx : idx + query_len])):
                 return idx
-        raise ValueError
+        raise ValueError(f"query phrase is not found: {query.audible_transcript}")
 
 
 class PronunciationDictionary:
     def __init__(self, data_dict: Mapping[str, List[PhonePhrase]]):
         self.word2phone = data_dict
 
+    @lru_cache(maxsize=SETTINGS.cache.cache_size)
+    def __contains__(self, key: str):
+        return key.strip().lower() in self.word2phone
+
+    @lru_cache(maxsize=SETTINGS.cache.cache_size)
     def encode(self, word: str) -> List[PhonePhrase]:
-        return self.word2phone[word.lower().strip()]
+        word = word.strip().lower()
+        if word not in self.word2phone:
+            raise ValueError(f"word is not in the dictionary: {word}")
+        return self.word2phone[word]
 
     @classmethod
     def from_file(cls, filename: Path):
         data = defaultdict(list)
         with filename.open() as f:
             for line in f:
-                if line.startswith(';'):
+                if line.startswith(";"):
+                    continue
+                word, _, pronunciation = line.partition(" ")
+                if len(word) == 0 or len(pronunciation) == 0:
                     continue
-                try:
-                    word, pronunciation = line.split(' ', 1)
-                    data[word.lower()].append(PhonePhrase(list(map(Phone, pronunciation.strip().lower().split()))))
-                except:
-                    pass
+                data[word.lower()].append(PhonePhrase.from_string(pronunciation.strip().lower()))
         return cls(data)
diff --git a/test/data/dataset_test.py b/test/data/dataset/dataset_test.py
similarity index 83%
rename from test/data/dataset_test.py
rename to test/data/dataset/dataset_test.py
index 5ef54061..b0855195 100644
--- a/test/data/dataset_test.py
+++ b/test/data/dataset/dataset_test.py
@@ -1,8 +1,13 @@
 import unittest
 
 import torch
-from howl.data.dataset import (AudioClipExample, AudioClipMetadata,
-                               AudioDataset, DatasetType)
+
+from howl.data.dataset import (
+    AudioClipExample,
+    AudioClipMetadata,
+    AudioDataset,
+    DatasetType,
+)
 from howl.data.searcher import WordTranscriptSearcher
 from howl.data.tokenize import Vocab
 from howl.settings import SETTINGS
@@ -10,23 +15,23 @@
 
 class TestDataset(AudioDataset[AudioClipMetadata]):
     """Sample dataset for testing"""
+
    __test__ = False
 
-    def __init__(self,
-                 **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
         self.sample_rate = SETTINGS.audio.sample_rate
         audio_data = torch.zeros(self.sample_rate)
 
-        metadata1 = AudioClipMetadata(transcription='hello world')
+        metadata1 = AudioClipMetadata(transcription="hello world")
         sample1 = AudioClipExample(metadata=metadata1, audio_data=audio_data, sample_rate=self.sample_rate)
 
-        metadata2 = AudioClipMetadata(transcription='happy new year')
+        metadata2 = AudioClipMetadata(transcription="happy new year")
         sample2 = AudioClipExample(metadata=metadata2, audio_data=audio_data, sample_rate=self.sample_rate)
 
-        metadata3 = AudioClipMetadata(transcription='what a beautiful world')
+        metadata3 = AudioClipMetadata(transcription="what a beautiful world")
         sample3 = AudioClipExample(metadata=metadata3, audio_data=audio_data, sample_rate=self.sample_rate)
 
         self.samples = [sample1, sample2, sample3]
 
@@ -39,7 +44,6 @@ def __getitem__(self, idx) -> AudioClipExample:
 
 
 class TestAudioDataset(unittest.TestCase):
-
     def test_compute_statistics(self):
         """test compute statistics
         """
@@ -47,7 +51,7 @@ def test_compute_statistics(self):
         SETTINGS.training.token_type = "word"
         SETTINGS.inference_engine.inference_sequence = [0, 1]
 
-        vocab = Vocab({"Hello": 0, "World": 1}, oov_token_id=2, oov_word_repr='<OOV>')
+        vocab = Vocab({"Hello": 0, "World": 1}, oov_token_id=2, oov_word_repr="<OOV>")
 
         searcher = WordTranscriptSearcher(vocab)
 
@@ -70,5 +74,5 @@ def test_compute_statistics(self):
         self.assertEqual(stat.vocab_counts["World"], 2)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/test/data/dataset/phone_test.py b/test/data/dataset/phone_test.py
new file mode 100644
index 00000000..76443448
--- /dev/null
+++ b/test/data/dataset/phone_test.py
@@ -0,0 +1,78 @@
+import unittest
+from pathlib import Path
+
+from howl.data.dataset import PhonePhrase, PronunciationDictionary
+
+
+class TestPhonePhrase(unittest.TestCase):
+    def test_basic_operations(self):
+        """test PhonePhrase instantiation and conversion to string
+        """
+        phone_phrase_str = " abc sil sp spn "
+        pp = PhonePhrase.from_string(phone_phrase_str)
+        self.assertEqual(pp.text, phone_phrase_str.strip())
+
+    def test_idx_operations(self):
+        """test idx conversions
+        """
+        phone_phrase_str = " abc sil sp spn def ghi sil "
+        pp = PhonePhrase.from_string(phone_phrase_str)
+
+        self.assertEqual(pp.all_idx_to_transcript_idx(0), 3)
+        self.assertEqual(pp.all_idx_to_transcript_idx(1), 7)
+        self.assertEqual(pp.all_idx_to_transcript_idx(2), 10)
+        self.assertEqual(pp.all_idx_to_transcript_idx(3), 14)
+        self.assertEqual(pp.all_idx_to_transcript_idx(4), 18)
+        self.assertEqual(pp.all_idx_to_transcript_idx(5), 22)
+        self.assertEqual(pp.all_idx_to_transcript_idx(6), 26)
+        self.assertRaises(ValueError, pp.all_idx_to_transcript_idx, 7)
+
+        abc_pp = PhonePhrase.from_string("abc")
+        def_pp = PhonePhrase.from_string("def")
+        def_ghi_pp = PhonePhrase.from_string("def ghi")
+        jki_pp = PhonePhrase.from_string("jki")
+        sil_pp = PhonePhrase.from_string("sil")
+
+        self.assertEqual(pp.audible_index(abc_pp), 0)
+        self.assertEqual(pp.audible_index(def_pp), 1)
+        self.assertEqual(pp.audible_index(def_ghi_pp), 1)
+        self.assertRaises(ValueError, pp.audible_index, jki_pp)
+        self.assertRaises(ValueError, pp.audible_index, sil_pp)
+
+        self.assertEqual(pp.audible_idx_to_all_idx(0), 0)
+        self.assertEqual(pp.audible_idx_to_all_idx(1), 4)
+        self.assertEqual(pp.audible_idx_to_all_idx(2), 5)
+        self.assertRaises(ValueError, pp.audible_idx_to_all_idx, 3)
+
+
+class TestPronounciationDictionary(unittest.TestCase):
+    def test_basic_operations(self):
+        """test PronunciationDictionary instantiation and phone phrase retrieval
+        """
+        pronounce_dict_file = Path("test/test_data/pronounciation_dictionary.txt")
+        pronounce_dict = PronunciationDictionary.from_file(pronounce_dict_file)
+
+        self.assertTrue("hey" in pronounce_dict)
+        self.assertTrue("HEY" in pronounce_dict)
+        self.assertTrue(" FIRE " in pronounce_dict)
+        self.assertFalse(" test " in pronounce_dict)
+        self.assertFalse("" in pronounce_dict)
+
+        self.assertRaises(ValueError, pronounce_dict.encode, "")
+
+        hey_phrases = pronounce_dict.encode("hey")
+        self.assertEqual(len(hey_phrases), 1)
+        self.assertEqual(hey_phrases[0].text, "hh ey1")
+
+        fire_phrases = pronounce_dict.encode("fire")
+        self.assertEqual(len(fire_phrases), 2)
+        self.assertEqual(fire_phrases[0].text, "f ay1 er0")
+        self.assertEqual(fire_phrases[1].text, "f ay1 r")
+
+        fox_phrases = pronounce_dict.encode("fox")
+        self.assertEqual(len(fox_phrases), 1)
+        self.assertEqual(fox_phrases[0].text, "f aa1 k s")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_data/pronounciation_dictionary.txt b/test/test_data/pronounciation_dictionary.txt
new file mode 100644
index 00000000..c98b63fb
--- /dev/null
+++ b/test/test_data/pronounciation_dictionary.txt
@@ -0,0 +1,4 @@
+HEY HH EY1
+FIRE F AY1 ER0
+FIRE F AY1 R
+FOX F AA1 K S
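
Usage sketch (illustrative, not part of the diff): the snippet below exercises the PronunciationDictionary and PhonePhrase APIs reworked above, assuming howl is importable and the pronunciation fixture added by this patch is on disk. The phrase "hh ey1 sil f ay1 er0" is an assumed example input, not something from the patch.

    from pathlib import Path

    from howl.data.dataset import PhonePhrase, PronunciationDictionary

    # Dictionary keys are lowercased on load and lookups are normalized, so
    # "FIRE" and " fire " resolve to the same entry; encode() now raises
    # ValueError for out-of-dictionary words instead of failing silently.
    pronounce_dict = PronunciationDictionary.from_file(Path("test/test_data/pronounciation_dictionary.txt"))
    fire_phrases = pronounce_dict.encode("fire")  # two variants: "f ay1 er0" and "f ay1 r"

    # Audible indices count only speech phones (sil/sp/spn are skipped); the
    # helpers convert between audible index, whole-phrase index, and the
    # character offset in the transcription.
    pp = PhonePhrase.from_string("hh ey1 sil f ay1 er0")
    audible_start = pp.audible_index(fire_phrases[0])   # 2: "f" is the third audible phone
    all_idx = pp.audible_idx_to_all_idx(audible_start)  # 3: accounts for the preceding sil
    print(pp.all_idx_to_transcript_idx(all_idx))        # 12: offset in str(pp) where "f" finishes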