From 3dca0dfc621837f6a0260a44218ac524b55b2193 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 12 Mar 2024 11:28:05 +0100 Subject: [PATCH 1/2] updated onmt/transforms/insert_mask_before_placeholder.py --- .../insert_mask_before_placeholder.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/onmt/transforms/insert_mask_before_placeholder.py b/onmt/transforms/insert_mask_before_placeholder.py index 0ea390c4f8..8f11626b3a 100755 --- a/onmt/transforms/insert_mask_before_placeholder.py +++ b/onmt/transforms/insert_mask_before_placeholder.py @@ -22,24 +22,28 @@ def add_options(cls, parser): "Transform/InsertMaskBeforePlaceholdersTransform" ) group.add( - "--response_pattern", - "-response_pattern", - type=str, + "--response_patterns", + "-response_patterns", help="Response patten to locate the end of the prompt", - default="Response : ⦅newline⦆", + default=["Response : ⦅newline⦆"], + nargs="+", ) def _parse_opts(self): - self.response_pattern = self.opts.response_pattern + self.response_patterns = self.opts.response_patterns def apply(self, example, is_train=False, stats=None, **kwargs): _src = " ".join(example["src"]) - if len(_src.split(self.response_pattern)) != 2: + response = None + for _pattern in self.response_patterns: + if len(_src.split(_pattern)) == 2: + prompt, response = _src.split(_pattern) + response = DefaultTokens.MASK_BEFORE.join([_pattern, response]) + if response is not None: + _src = "".join([prompt, response]) + example["src"] = _src.split(" ") + example["tgt"] = _src.split(" ") + else: logger.info("The mask_before could not be inserted") return example - prompt, response = _src.split(self.response_pattern) - response = DefaultTokens.MASK_BEFORE.join([self.response_pattern, response]) - _src = "".join([prompt, response]) - example["src"] = _src.split(" ") - example["tgt"] = _src.split(" ") return example From eacda99bedd7370cd69bcb669611fc7ecc90c928 Mon Sep 17 00:00:00 2001 From: l-k-11235 Date: Tue, 12 Mar 
2024 11:41:42 +0100 Subject: [PATCH 2/2] updated onmt/tests/test_transform.py --- onmt/tests/test_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/tests/test_transform.py b/onmt/tests/test_transform.py index 4eef58d15f..a5c5cc6f70 100644 --- a/onmt/tests/test_transform.py +++ b/onmt/tests/test_transform.py @@ -788,7 +788,7 @@ class TestInsertMaskBeforePlaceholder(unittest.TestCase): @classmethod def setUpClass(cls): cls.base_opts = { - "response_pattern": "Response : ⦅newline⦆", + "response_patterns": ["Response : ⦅newline⦆"], } def test_insert_mask_before_placeholder(self):