diff --git a/lib/classes/tts_manager.py b/lib/classes/tts_manager.py index 95975d20..95002d68 100644 --- a/lib/classes/tts_manager.py +++ b/lib/classes/tts_manager.py @@ -63,6 +63,7 @@ def __init__(self, session): self._build() def _build(self): + self.params['curent_voice_path'] = None if self.session['tts_engine'] == XTTSv2: if self.session['custom_model'] is not None: self.model_name = os.path.basename(self.session['custom_model']) @@ -176,7 +177,9 @@ def convert_sentence_to_audio(self): else models[self.session['tts_engine']][self.session['fine_tuned']]['voice'] if self.session['fine_tuned'] else models[self.session['tts_engine']]['internal']['voice'] ) - self.params['gpt_cond_latent'], self.params['speaker_embedding'] = self.params['tts'].get_conditioning_latents(audio_path=[self.params['voice_path']]) + if self.params['curent_voice_path'] != self.params['voice_path']: + self.params['curent_voice_path'] = self.params['voice_path'] + self.params['gpt_cond_latent'], self.params['speaker_embedding'] = self.params['tts'].get_conditioning_latents(audio_path=[self.params['voice_path']]) with torch.no_grad(): result = self.params['tts'].inference( text=self.params['sentence'], diff --git a/lib/functions.py b/lib/functions.py index 823719bd..6db00080 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -549,7 +549,8 @@ def filter_chapter(doc, lang, lang_iso1, tts_engine): text = normalize_text(text, lang, lang_iso1, tts_engine) # Create regex pattern from punctuation list to split the phoneme_list escaped_punctuation = re.escape(''.join(punctuation_list)) - punctuation_pattern_split = rf'([^{"".join(escaped_punctuation)}]+|[{escaped_punctuation}])' + #punctuation_pattern_split = rf'([^{"".join(escaped_punctuation)}]+|[{escaped_punctuation}])' + punctuation_pattern_split = rf'(\S.*?[{"".join(escaped_punctuation)}])|\S+' # Split by punctuation marks while keeping the punctuation at the end of each word tmp_list = re.findall(punctuation_pattern_split, text) phoneme_list = [phoneme.strip() for phoneme in tmp_list if phoneme.strip()] @@ -575,7 +576,7 @@ def filter_pattern(doc_identifier): elif re.match(r'^\d+$', segment): return 'numbers' return None - +''' def get_sentences(phoneme_list, max_tokens): sentences = [] current_sentence = "" @@ -602,7 +603,42 @@ def get_sentences(phoneme_list, max_tokens): if current_sentence: sentences.append(current_sentence.strip()) return sentences - +''' + +def get_sentences(phoneme_list, max_tokens): + sentences = [] + current_sentence = "" + current_phoneme_count = 0 + for phoneme in phoneme_list: + part_phoneme_count = len(phoneme.split()) + # Always append to current sentence unless punctuation is hit + if current_phoneme_count + part_phoneme_count > max_tokens: + # Ensure we finalize the sentence at punctuation, not a space + if any(current_sentence.endswith(punc) for punc in punctuation_list): + sentences.append(current_sentence.strip()) + current_sentence = phoneme + current_phoneme_count = part_phoneme_count + else: + # Look back and split at last punctuation instead of splitting randomly + last_punc_index = max( + (current_sentence.rfind(punc) for punc in punctuation_list if punc in current_sentence), + default=-1 + ) + if last_punc_index > -1: + sentences.append(current_sentence[:last_punc_index+1].strip()) # Keep punctuation + current_sentence = current_sentence[last_punc_index+1:].strip() + " " + phoneme + current_phoneme_count = len(current_sentence.split()) + else: + sentences.append(current_sentence.strip()) + current_sentence = phoneme + current_phoneme_count = part_phoneme_count + else: + current_sentence += (" " if current_sentence else "") + phoneme + current_phoneme_count += part_phoneme_count + if current_sentence: + sentences.append(current_sentence.strip()) + return sentences + def get_batch_size(list, session): total_size = 0 print(list) @@ -1596,7 +1632,6 @@ def process_cleanup(state): label='Enable Text Splitting', value=default_xtts_settings['enable_text_splitting'], info='Coqui-tts builtin text splitting. Can help against hallucinations bu can also be worse.', - visible=False ) gr_state = gr.State(value={"hash": None}) @@ -2076,6 +2111,8 @@ def change_gr_fine_tuned_list(selected, id): visible = False if selected == 'internal' and session['tts_engine'] == XTTSv2: visible = visible_gr_group_custom_model + else: + visible = False session['fine_tuned'] = selected return gr.update(visible=visible)