From f434710aca04312ef58235d1f58eed6865ec8af0 Mon Sep 17 00:00:00 2001 From: Martin Popel Date: Sat, 23 Mar 2024 13:15:12 +0100 Subject: [PATCH 1/5] `--list -t wmt23 -l cs-uk` should print just cs-uk, not other language pairs When omitting `-l`, `--list` will still print all the language pairs for that test set. Motivation: Originally, `--list` showed just the list of language pairs, so there was no reason to call it with `-l`, but now it lists all the **fields** for a given language pair and it is relatively slow (it has to parse the XML files), so it makes sense to restrict the listing to a single language pair only. --- sacrebleu/sacrebleu.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sacrebleu/sacrebleu.py b/sacrebleu/sacrebleu.py index d778e1d..6b7cd9e 100755 --- a/sacrebleu/sacrebleu.py +++ b/sacrebleu/sacrebleu.py @@ -242,8 +242,7 @@ def main(): if args.list: if args.test_set: - langpairs = get_langpairs_for_testset(args.test_set) - for pair in langpairs: + for pair in [args.langpair] if args.langpair else get_langpairs_for_testset(args.test_set): fields = DATASETS[args.test_set].fieldnames(pair) print(f'{pair}: {", ".join(fields)}') else: From 163b5941040510a7c3dfd54698629cdd6317a515 Mon Sep 17 00:00:00 2001 From: Martin Popel Date: Sun, 24 Mar 2024 22:08:14 +0100 Subject: [PATCH 2/5] allow `--detail` and `--subset` to be used also with the new XML test sets --- sacrebleu/dataset/wmt_xml.py | 29 +++++----- sacrebleu/utils.py | 102 ++++++++++++++++++++++++----------- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/sacrebleu/dataset/wmt_xml.py b/sacrebleu/dataset/wmt_xml.py index 4f78bcc..1aedf1d 100644 --- a/sacrebleu/dataset/wmt_xml.py +++ b/sacrebleu/dataset/wmt_xml.py @@ -26,10 +26,12 @@ def _unwrap_wmt21_or_later(raw_file): This script is adapted from https://github.com/wmt-conference/wmt-format-tools :param raw_file: The raw xml file to unwrap. - :return: Dictionary which contains the following fields: + :return: Dictionary which contains the following fields + (each a list with values for each sentence): - `src`: The source sentences. - `docid`: ID indicating which document the sentences belong to. - `origlang`: The original language of the document. + - `domain`: Domain of the document. - `ref:{translator}`: The references produced by each translator. - `ref`: An alias for the references from the first translator. """ @@ -60,13 +62,8 @@ def _unwrap_wmt21_or_later(raw_file): systems = defaultdict(list) - src_sent_count, doc_count = 0, 0 + src_sent_count, doc_count, seen_domain = 0, 0, False for doc in tree.getroot().findall(".//doc"): - docid = doc.attrib["id"] - origlang = doc.attrib["origlang"] - # present wmt22++ - domain = doc.attrib.get("domain", None) - # Skip the testsuite if "testsuite" in doc.attrib: continue @@ -104,17 +101,17 @@ def get_sents(doc): src.append(src_sents[seg_id]) for system_name in hyps.keys(): systems[system_name].append(hyps[system_name][seg_id]) - docids.append(docid) - orig_langs.append(origlang) - if domain is not None: - domains.append(domain) + docids.append(doc.attrib["id"]) + orig_langs.append(doc.attrib["origlang"]) + # The "domain" attribute is missing in WMT21 and WMT22 + domains.append(doc.get("domain")) + seen_domain = doc.get("domain") is not None src_sent_count += 1 - data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} - if len(domains): - data["domain"] = domains - - return data + fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} + if seen_domain: + fields["domain"] = domains + return fields def _get_langpair_path(self, langpair): """ diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py index 56e6fca..cddc3a7 100644 --- a/sacrebleu/utils.py +++ b/sacrebleu/utils.py @@ -488,7 +488,7 @@ def get_available_testsets_for_langpair(langpair: str) -> List[str]: def get_available_origlangs(test_sets, langpair) -> List[str]: - """Return a list of origlang values in according to the raw SGM files.""" + """Return a list of origlang values according to the raw XML/SGM files.""" if test_sets is None: return [] @@ -496,6 +496,10 @@ def get_available_origlangs(test_sets, langpair) -> List[str]: for test_set in test_sets.split(','): dataset = DATASETS[test_set] rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0]) + from .dataset.wmt_xml import WMTXMLDataset + if isinstance(dataset, WMTXMLDataset): + for origlang in dataset._unwrap_wmt21_or_later(rawfile)['origlang']: + origlangs.add(origlang) if rawfile.endswith('.sgm'): with smart_open(rawfile) as fin: for line in fin: @@ -505,6 +509,25 @@ def get_available_origlangs(test_sets, langpair) -> List[str]: return sorted(list(origlangs)) +def get_available_subsets(test_sets, langpair) -> List[str]: + """Return a list of domain values according to the raw XML files and domain/country values from the SGM files.""" + if test_sets is None: + return [] + + subsets = set() + for test_set in test_sets.split(','): + dataset = DATASETS[test_set] + from .dataset.wmt_xml import WMTXMLDataset + if isinstance(dataset, WMTXMLDataset): + rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0]) + fields = dataset._unwrap_wmt21_or_later(rawfile) + if 'domain' in fields: + subsets |= set(fields['domain']) + elif test_set in SUBSETS: + subsets |= set("country:" + v.split("-")[0] for v in SUBSETS[test_set].values()) + subsets |= set(v.split("-")[1] for v in SUBSETS[test_set].values()) + return sorted(list(subsets)) + def filter_subset(systems, test_sets, langpair, origlang, subset=None): """Filter sentences with a given origlang (or subset) according to the raw SGM files.""" if origlang is None and subset is None: @@ -516,37 +539,49 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): re_id = re.compile(r'.* docid="([^"]+)".*\n') indices_to_keep = [] - for test_set in test_sets.split(','): dataset = DATASETS[test_set] rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0]) - if not rawfile.endswith('.sgm'): - raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}') - if subset is not None: - if test_set not in SUBSETS: - raise Exception('No subset annotation available for test set ' + test_set) - doc_to_tags = SUBSETS[test_set] - number_sentences_included = 0 - with smart_open(rawfile) as fin: - include_doc = False - for line in fin: - if line.startswith(' Date: Tue, 16 Jul 2024 01:06:32 -0600 Subject: [PATCH 3/5] Fix lint errors --- sacrebleu/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py index cddc3a7..0411d91 100644 --- a/sacrebleu/utils.py +++ b/sacrebleu/utils.py @@ -612,7 +612,7 @@ def print_subset_results(metrics, full_system, full_refs, args): key = f'origlang={origlang}' if subset is None: - key += f' domain=ALL' + key += ' domain=ALL' elif subset.startswith('country:'): key += f' country={subset[8:]}' else: @@ -630,4 +630,4 @@ def print_subset_results(metrics, full_system, full_refs, args): print(f'{key}: sentences={n_system:<6} {score.name:<{max_metric_width}} = {score.score:.{w}f}') # import at the end to avoid circular import -from .dataset import DATASETS, SUBSETS, DOMAINS, COUNTRIES # noqa: E402 +from .dataset import DATASETS, SUBSETS # noqa: E402 From 9a28bb7626ec53973709b8852bf7ce4a35ab7fed Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Wed, 17 Jul 2024 18:26:51 -0600 Subject: [PATCH 4/5] Fix CI errors --- sacrebleu/utils.py | 4 ++++ test.sh | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py index 0411d91..834f4b5 100644 --- a/sacrebleu/utils.py +++ b/sacrebleu/utils.py @@ -535,6 +535,9 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): if test_sets is None or langpair is None: raise ValueError('Filtering for --origlang or --subset needs a test (-t) and a language pair (-l).') + if subset is not None and subset.startswith('country:'): + subset = subset[8:] + re_origlang = re.compile(r'.* origlang="([^"]+)".*\n') re_id = re.compile(r'.* docid="([^"]+)".*\n') @@ -557,6 +560,7 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): include_doc = False indices_to_keep.append(include_doc) elif rawfile.endswith('.sgm'): + doc_to_tags = {} if subset is not None: if test_set not in SUBSETS: raise Exception('No subset annotation available for test set ' + test_set) diff --git a/test.sh b/test.sh index 1bb5720..cc4c62f 100755 --- a/test.sh +++ b/test.sh @@ -74,16 +74,16 @@ declare -A EXPECTED EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi"]=53.7432 EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi --origlang=en"]=18.9054 EXPECTED["${CMD} -t wmt17 -l en-fi --echo ref | ${CMD} -b -t wmt17/B -l en-fi --detail"]="55.6 -origlang=en : sentences=1502 BLEU = 21.4 -origlang=fi : sentences=1500 BLEU = 100.0" +origlang=en domain=ALL : sentences=1502 BLEU = 21.4 +origlang=fi domain=ALL : sentences=1500 BLEU = 100.0" EXPECTED["${CMD} -t wmt18,wmt19 -l en-de --echo=src | ${CMD} -t wmt18,wmt19 -l en-de -b --detail"]="3.6 -origlang=de : sentences=1498 BLEU = 3.6 -origlang=en : sentences=3497 BLEU = 3.5 +origlang=de domain=ALL : sentences=1498 BLEU = 3.6 +origlang=en domain=ALL : sentences=3497 BLEU = 3.5 +origlang=en domain=business : sentences=241 BLEU = 3.4 origlang=en country=EU : sentences=265 BLEU = 2.5 origlang=en country=GB : sentences=913 BLEU = 3.1 origlang=en country=OTHER : sentences=801 BLEU = 2.5 origlang=en country=US : sentences=1518 BLEU = 4.2 -origlang=en domain=business : sentences=241 BLEU = 3.4 origlang=en domain=crime : sentences=570 BLEU = 3.6 origlang=en domain=entertainment : sentences=322 BLEU = 5.1 origlang=en domain=politics : sentences=959 BLEU = 3.0 From 2a4cdde2df02fd0ee9d1629b43c021a082cba435 Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Tue, 23 Jul 2024 01:31:51 -0600 Subject: [PATCH 5/5] Fix domain field can be None --- sacrebleu/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py index 834f4b5..34d286c 100644 --- a/sacrebleu/utils.py +++ b/sacrebleu/utils.py @@ -1,3 +1,4 @@ +import itertools import json import os import re @@ -548,7 +549,8 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None): from .dataset.wmt_xml import WMTXMLDataset if isinstance(dataset, WMTXMLDataset): fields = dataset._unwrap_wmt21_or_later(rawfile) - for doc_origlang, doc_domain in zip(fields['origlang'], fields['domain']): + domains = fields['domain'] if 'domain' in fields else itertools.repeat(None) + for doc_origlang, doc_domain in zip(fields['origlang'], domains): if origlang is None: include_doc = True else: