diff --git a/nlpaug/augmenter/word/context_word_embs.py b/nlpaug/augmenter/word/context_word_embs.py index 4aa72a5..6b25431 100755 --- a/nlpaug/augmenter/word/context_word_embs.py +++ b/nlpaug/augmenter/word/context_word_embs.py @@ -275,10 +275,12 @@ def insert(self, data): augmented_text += ' ' + tail_text augmented_texts.append(augmented_text) - if isinstance(data, list): - return augmented_texts + augmented_texts = augmented_texts if isinstance(data, list) else augmented_texts[0] + + if self.include_detail: + return augmented_texts, head_doc.get_change_logs() else: - return augmented_texts[0] + return augmented_texts def substitute(self, data): if not data: @@ -414,10 +416,12 @@ def substitute(self, data): augmented_text += ' ' + tail_text augmented_texts.append(augmented_text) - if isinstance(data, list): - return augmented_texts + augmented_texts = augmented_texts if isinstance(data, list) else augmented_texts[0] + + if self.include_detail: + return augmented_texts, head_doc.get_change_logs() else: - return augmented_texts[0] + return augmented_texts @classmethod def get_model(cls, model_path, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,