Skip to content

Commit

Permalink
fixing phone class (#76)
Browse files Browse the repository at this point in the history
* fixing phone class

* adding documentation
  • Loading branch information
ljj7975 authored Apr 16, 2021
1 parent 71c8c15 commit 55f8b92
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 85 deletions.
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ repos:
rev: 3.8.3
hooks:
- id: flake8
args: [--max-line-length=120]
args:
- --max-line-length=120
- --ignore=E203
- repo: https://github.com/pre-commit/mirrors-pylint
rev: v2.6.0
hooks:
Expand Down
68 changes: 35 additions & 33 deletions howl/context.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List

from howl.data.dataset import (PhonePhrase, PhoneticFrameLabeler,
PronunciationDictionary, WakeWordDataset,
WordFrameLabeler)
from howl.data.searcher import (LabelColoring, PhoneticTranscriptSearcher,
WordTranscriptSearcher)
from howl.data.dataset import (
PhonePhrase,
PhoneticFrameLabeler,
PronunciationDictionary,
WakeWordDataset,
WordFrameLabeler,
)
from howl.data.searcher import (
LabelColoring,
PhoneticTranscriptSearcher,
WordTranscriptSearcher,
)
from howl.data.tokenize import Vocab
from howl.settings import SETTINGS

Expand All @@ -21,64 +29,58 @@ class WakewordDatasetContext:


class InferenceContext:
def __init__(self,
vocab: List[str],
token_type: str = 'phone',
pronounce_dict: PronunciationDictionary = None,
use_blank: bool = False):
def __init__(self, vocab: List[str], token_type: str = "phone", use_blank: bool = False):

self.coloring = None
self.adjusted_vocab = []
self.num_labels = 0
self.token_type = token_type
self.pronounce_dict = None

# break down each vocab into phonemes
if token_type == 'phone':
if pronounce_dict is None:
pronounce_dict = PronunciationDictionary.from_file(
SETTINGS.training.phone_dictionary)
if token_type == "phone":
self.pronounce_dict = PronunciationDictionary.from_file(Path(SETTINGS.training.phone_dictionary))

self.coloring = LabelColoring()
for word in vocab:
phone_phrases = pronounce_dict.encode(word)
logging.info(
f'Using phonemes {str(phone_phrases)} for word {word}')
self.add_vocab(x.text for x in phone_phrases)
# TODO:: we currently use single representation for simplicity
phone_phrase = self.pronounce_dict.encode(word)[0]
logging.info(f"Word {word: <10} has phonemes of {str(phone_phrase)}")
self.add_vocab(list(str(phone) for phone in phone_phrase.phones))

elif token_type == 'word':
elif token_type == "word":
self.add_vocab(vocab)

# initialize vocab set for the system
self.negative_label = len(self.adjusted_vocab)
self.vocab = Vocab({word: idx for idx, word in enumerate(
self.adjusted_vocab)}, oov_token_id=self.negative_label)
self.vocab = Vocab(
{word: idx for idx, word in enumerate(self.adjusted_vocab)}, oov_token_id=self.negative_label
)

# initialize labeler; make sure this is located before adding other labels
if token_type == 'phone':
phone_phrases = [PhonePhrase.from_string(
x) for x in self.adjusted_vocab]
self.labeler = PhoneticFrameLabeler(phone_phrases)
elif token_type == 'word':
if token_type == "phone":
phone_phrases = [PhonePhrase.from_string(x) for x in self.adjusted_vocab]
self.labeler = PhoneticFrameLabeler(self.pronounce_dict, phone_phrases)
elif token_type == "word":
self.labeler = WordFrameLabeler(self.vocab)

# add negative label
self.add_vocab(['[OOV]'])
self.add_vocab(["[OOV]"])

# initialize TranscriptSearcher with the processed targets
if token_type == 'phone':
self.searcher = PhoneticTranscriptSearcher(
phone_phrases, self.coloring)
elif token_type == 'word':
if token_type == "phone":
self.searcher = PhoneticTranscriptSearcher(phone_phrases, self.coloring)
elif token_type == "word":
self.searcher = WordTranscriptSearcher(self.vocab)

# add extra label for blank if necessary
self.blank_label = -1
if use_blank:
self.blank_label = len(self.adjusted_vocab)
self.add_vocab(['[BLANK]'])
self.add_vocab(["[BLANK]"])

for idx, word in enumerate(self.adjusted_vocab):
logging.info(f'target {word:10} is assigned to label {idx}')
logging.info(f"target {word:10} is assigned to label {idx}")

def add_vocab(self, vocabs: List[str]):
for vocab in vocabs:
Expand Down
22 changes: 12 additions & 10 deletions howl/data/dataset/labeller.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@

from dataclasses import dataclass
from typing import List

from howl.data.dataset.phone import PhonePhrase
from howl.data.dataset.phone import PhonePhrase, PronunciationDictionary
from howl.data.tokenize import Vocab

from .base import AudioClipMetadata, FrameLabelData

__all__ = ['FrameLabeler',
'WordFrameLabeler',
'PhoneticFrameLabeler']
__all__ = ["FrameLabeler", "WordFrameLabeler", "PhoneticFrameLabeler"]


class FrameLabeler:
Expand All @@ -18,12 +14,17 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:


class PhoneticFrameLabeler(FrameLabeler):
def __init__(self, phrases: List[PhonePhrase]):
def __init__(self, pronounce_dict: PronunciationDictionary, phrases: List[PhonePhrase]):
self.pronounce_dict = pronounce_dict
self.phrases = phrases

def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
frame_labels = dict()
start_timestamp = []
char_indices = []

start = 0
# TODO:: must be pronounciation instead of the transcription
pp = PhonePhrase.from_string(metadata.transcription)
for idx, phrase in enumerate(self.phrases):
while True:
Expand All @@ -35,7 +36,8 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
start = pp.all_idx_to_transcript_idx(pp.audible_idx_to_all_idx(start))
frame_labels[metadata.end_timestamps[start + len(str(phrase)) - 1]] = idx
start += 1
return FrameLabelData(frame_labels)

return FrameLabelData(frame_labels, start_timestamp, char_indices)


class WordFrameLabeler(FrameLabeler):
Expand All @@ -45,8 +47,8 @@ def __init__(self, vocab: Vocab, ceil_word_boundary: bool = False):

def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
frame_labels = dict()
char_indices = []
start_timestamp = []
char_indices = []

char_idx = 0
for word in metadata.transcription.split():
Expand All @@ -59,7 +61,7 @@ def compute_frame_labels(self, metadata: AudioClipMetadata) -> FrameLabelData:
end_timestamp = metadata.end_timestamps[char_idx + word_size - 1]
frame_labels[end_timestamp] = label
char_indices.append((label, list(range(char_idx, char_idx + word_size))))
start_timestamp.append((label, metadata.end_timestamps[char_idx-1] if char_idx > 0 else 0.0))
start_timestamp.append((label, metadata.end_timestamps[char_idx - 1] if char_idx > 0 else 0.0))

char_idx += word_size + 1 # space

Expand Down
130 changes: 99 additions & 31 deletions howl/data/dataset/phone.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import enum
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import List, Mapping

__all__ = ['PhoneEnum',
'PronunciationDictionary',
'PhonePhrase']
from howl.settings import SETTINGS

__all__ = ["PhoneEnum", "PronunciationDictionary", "PhonePhrase"]


class PhoneEnum(enum.Enum):
SILENCE = 'sil'
SILENCE_OPTIONAL = 'sp'
SPEECH_UNKNOWN = 'spn'
SILENCE = "sil"
SILENCE_OPTIONAL = "sp"
SPEECH_UNKNOWN = "spn"


@dataclass
Expand All @@ -21,14 +22,16 @@ class Phone:

def __post_init__(self):
self.text = self.text.lower().strip()
self.is_speech = self.text not in (PhoneEnum.SILENCE.value,
PhoneEnum.SILENCE_OPTIONAL.value,
PhoneEnum.SPEECH_UNKNOWN.value)
self.is_speech = self.text not in (
PhoneEnum.SILENCE.value,
PhoneEnum.SILENCE_OPTIONAL.value,
PhoneEnum.SPEECH_UNKNOWN.value,
)

def __str__(self):
return self.text

def __eq__(self, other: 'Phone'):
def __eq__(self, other: "Phone"):
return other.text == self.text


Expand All @@ -38,7 +41,7 @@ class PhonePhrase:

def __post_init__(self):
self.audible_phones = [x for x in self.phones if x.is_speech]
self.audible_transcript = ' '.join(x.text for x in self.audible_phones)
self.audible_transcript = " ".join(x.text for x in self.audible_phones)
self.sil_indices = [idx for idx, x in enumerate(self.phones) if not x.is_speech]

@property
Expand All @@ -50,44 +53,109 @@ def from_string(cls, string: str):
return cls([Phone(x) for x in string.split()])

def __str__(self):
return ' '.join(x.text for x in self.phones)

def all_idx_to_transcript_idx(self, index: int) -> int:
return sum(map(len, [phone.text for phone in self.phones[:index]])) + index

def audible_idx_to_all_idx(self, index: int) -> int:
return " ".join(x.text for x in self.phones)

def all_idx_to_transcript_idx(self, phone_idx: int) -> int:
"""get index based on transcription for the given phone_idx
pp = PhonePhrase.from_string("abc def ghi")
pp.all_idx_to_transcript_idx(0) # 3 - where the first phone (abc) finishes
pp.all_idx_to_transcript_idx(1) # 7 - where the second phone (def) finishes
pp.all_idx_to_transcript_idx(2) # 11 - where the third phone (ghi) finishes
Args:
phone_idx (int): target phone idx
Raises:
ValueError: if phone idx is out of bound
Returns:
int: transcription idx where the phone at the given phone_idx finishes
"""
if phone_idx >= len(self.phones):
raise ValueError(f"Given phone idx ({phone_idx}) is greater than the number of phones ({len(self.phones)})")
all_idx_without_space = sum(map(len, [phone.text for phone in self.phones[: phone_idx + 1]]))
return all_idx_without_space + phone_idx # add phone_idx for spaces between phones

def audible_idx_to_all_idx(self, audible_idx: int) -> int:
"""convert given audible index to phone index including all non speech phones
pp = PhonePhrase.from_string("abc sil ghi")
pp.audible_idx_to_all_idx(0) # 0 - where the first audible phone (abc) is located in the whole phrase
pp.audible_idx_to_all_idx(1) # 2 - where the second audible phone (abc) is located in the whole phrase
Args:
audible_idx (int): audible phone index to convert
Raises:
ValueError: if audible phone index is out of bound
Returns:
int: the index of the audible phone in the whole phrase
"""
if audible_idx >= len(self.audible_phones):
raise ValueError(
f"Given audible phone idx ({audible_idx}) is greater than"
"the number of audible phones ({len(self.audible_phones)})"
)
offset = 0
for sil_idx in self.sil_indices:
if sil_idx <= index + offset:
if sil_idx <= audible_idx + offset:
offset += 1
return offset + index
return offset + audible_idx

def audible_index(self, query: "PhonePhrase", start: int = 0) -> int:
"""find the starting audible index of the given phrase in the current phrase
pp = PhonePhrase.from_string("abc sil ghi")
ghi_pp = PhonePhrase.from_string("ghi")
pp.audible_index(ghi_pp, 0) # 1 - audible index of the query phone (ghi)
def audible_index(self, item: 'PhonePhrase', start: int = 0):
item_len = len(item.audible_phones)
Args:
query (PhonePhrase): phone phrase to be searched
start (int, optional): starting index in the whole phrase. Defaults to 0.
Raises:
ValueError: when the query phone phrase does not contain any phone
ValueError: when the query phone phrase is not found
Returns:
int: audible index of the query phone phase
"""
query_len = len(query.audible_phones)
if query_len == 0:
raise ValueError(f"query phrase has empty audible_phones: {query.audible_transcript}")
self_len = len(self.audible_phones)
for idx in range(start, self_len - item_len + 1):
if all(x == y for x, y in zip(item.audible_phones, self.audible_phones[idx:idx + item_len])):
for idx in range(start, self_len - query_len + 1):
if all(x == y for x, y in zip(query.audible_phones, self.audible_phones[idx : idx + query_len])):
return idx
raise ValueError
raise ValueError(f"query phrase is not found: {query.audible_transcript}")


class PronunciationDictionary:
def __init__(self, data_dict: Mapping[str, List[PhonePhrase]]):
self.word2phone = data_dict

@lru_cache(maxsize=SETTINGS.cache.cache_size)
def __contains__(self, key: str):
return key.strip().lower() in self.word2phone

@lru_cache(maxsize=SETTINGS.cache.cache_size)
def encode(self, word: str) -> List[PhonePhrase]:
return self.word2phone[word.lower().strip()]
word = word.strip().lower()
if word not in self.word2phone:
raise ValueError(f"word is not in the dictionary: {word}")
return self.word2phone[word.strip().lower()]

@classmethod
def from_file(cls, filename: Path):
data = defaultdict(list)
with filename.open() as f:
for line in f:
if line.startswith(';'):
if line.startswith(";"):
continue
word, pronunciation = line.split(" ", 1)
if len(word) == 0 or len(pronunciation) == 0:
continue
try:
word, pronunciation = line.split(' ', 1)
data[word.lower()].append(PhonePhrase(list(map(Phone, pronunciation.strip().lower().split()))))
except:
pass
data[word.lower()].append(PhonePhrase.from_string(pronunciation.strip().lower()))
return cls(data)
Loading

0 comments on commit 55f8b92

Please sign in to comment.