diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 78d1554c77..56a958e2fa 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -788,7 +788,7 @@ class TestInsertMaskBeforePlaceholder(unittest.TestCase): @classmethod def setUpClass(cls): cls.base_opts = { - "response_pattern": "Response : ⦅newline⦆", + "response_patterns": ["Response : ⦅newline⦆"], } def test_insert_mask_before_placeholder(self): diff --git a/onmt/transforms/insert_mask_before_placeholder.py b/onmt/transforms/insert_mask_before_placeholder.py index 0ea390c4f8..8f11626b3a 100755 --- a/onmt/transforms/insert_mask_before_placeholder.py +++ b/onmt/transforms/insert_mask_before_placeholder.py @@ -22,24 +22,28 @@ def add_options(cls, parser): "Transform/InsertMaskBeforePlaceholdersTransform" ) group.add( - "--response_pattern", - "-response_pattern", - type=str, + "--response_patterns", + "-response_patterns", help="Response patten to locate the end of the prompt", - default="Response : ⦅newline⦆", + default=["Response : ⦅newline⦆"], + nargs="+", ) def _parse_opts(self): - self.response_pattern = self.opts.response_pattern + self.response_patterns = self.opts.response_patterns def apply(self, example, is_train=False, stats=None, **kwargs): _src = " ".join(example["src"]) - if len(_src.split(self.response_pattern)) != 2: + response = None + for _pattern in self.response_patterns: + if len(_src.split(_pattern)) == 2: + prompt, response = _src.split(_pattern) + response = DefaultTokens.MASK_BEFORE.join([_pattern, response]) + if response is not None: + _src = "".join([prompt, response]) + example["src"] = _src.split(" ") + example["tgt"] = _src.split(" ") + else: logger.info("The mask_before could not be inserted") return example - prompt, response = _src.split(self.response_pattern) - response = DefaultTokens.MASK_BEFORE.join([self.response_pattern, response]) - _src = "".join([prompt, response]) - example["src"] = _src.split(" ") - example["tgt"] = _src.split(" ") return example