Check if all docs have domain attribute #267

Merged (5 commits) on Jul 23, 2024
29 changes: 13 additions & 16 deletions sacrebleu/dataset/wmt_xml.py
@@ -26,10 +26,12 @@ def _unwrap_wmt21_or_later(raw_file):
This script is adapted from https://github.com/wmt-conference/wmt-format-tools

:param raw_file: The raw xml file to unwrap.
:return: Dictionary which contains the following fields:
:return: Dictionary which contains the following fields
(each a list with values for each sentence):
- `src`: The source sentences.
- `docid`: ID indicating which document the sentences belong to.
- `origlang`: The original language of the document.
- `domain`: Domain of the document.
- `ref:{translator}`: The references produced by each translator.
- `ref`: An alias for the references from the first translator.
"""
@@ -60,13 +62,8 @@ def _unwrap_wmt21_or_later(raw_file):

systems = defaultdict(list)

src_sent_count, doc_count = 0, 0
src_sent_count, doc_count, seen_domain = 0, 0, False
for doc in tree.getroot().findall(".//doc"):
docid = doc.attrib["id"]
origlang = doc.attrib["origlang"]
# present wmt22++
domain = doc.attrib.get("domain", None)

# Skip the testsuite
if "testsuite" in doc.attrib:
continue
@@ -104,17 +101,17 @@ def get_sents(doc):
src.append(src_sents[seg_id])
for system_name in hyps.keys():
systems[system_name].append(hyps[system_name][seg_id])
docids.append(docid)
orig_langs.append(origlang)
if domain is not None:
domains.append(domain)
docids.append(doc.attrib["id"])
orig_langs.append(doc.attrib["origlang"])
# The "domain" attribute is missing in WMT21 and WMT22
domains.append(doc.get("domain"))
seen_domain = doc.get("domain") is not None
src_sent_count += 1

data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if len(domains):
data["domain"] = domains

return data
fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if seen_domain:
fields["domain"] = domains
return fields

def _get_langpair_path(self, langpair):
"""
3 changes: 1 addition & 2 deletions sacrebleu/sacrebleu.py
@@ -242,8 +242,7 @@ def main():

if args.list:
if args.test_set:
langpairs = get_langpairs_for_testset(args.test_set)
for pair in langpairs:
for pair in [args.langpair] if args.langpair else get_langpairs_for_testset(args.test_set):
fields = DATASETS[args.test_set].fieldnames(pair)
print(f'{pair}: {", ".join(fields)}')
else:
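Spelled out, the condensed loop added above behaves like the sketch below (variable and function names as in main(); this expansion is illustrative, not code from the PR): with an explicit language pair only that pair's fields are listed, otherwise every pair of the test set is listed.

```python
# Illustrative expansion of the new --list branch.
pairs = [args.langpair] if args.langpair else get_langpairs_for_testset(args.test_set)
for pair in pairs:
    fields = DATASETS[args.test_set].fieldnames(pair)
    print(f'{pair}: {", ".join(fields)}')
```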
110 changes: 77 additions & 33 deletions sacrebleu/utils.py
@@ -1,3 +1,4 @@
import itertools
import json
import os
import re
@@ -488,14 +489,18 @@ def get_available_testsets_for_langpair(langpair: str) -> List[str]:


def get_available_origlangs(test_sets, langpair) -> List[str]:
"""Return a list of origlang values in according to the raw SGM files."""
"""Return a list of origlang values according to the raw XML/SGM files."""
if test_sets is None:
return []

origlangs = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
for origlang in dataset._unwrap_wmt21_or_later(rawfile)['origlang']:
origlangs.add(origlang)
if rawfile.endswith('.sgm'):
with smart_open(rawfile) as fin:
for line in fin:
@@ -505,48 +510,84 @@ def get_available_origlangs(test_sets, langpair) -> List[str]:
return sorted(list(origlangs))


def get_available_subsets(test_sets, langpair) -> List[str]:
"""Return a list of domain values according to the raw XML files and domain/country values from the SGM files."""
if test_sets is None:
return []

subsets = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
fields = dataset._unwrap_wmt21_or_later(rawfile)
if 'domain' in fields:
subsets |= set(fields['domain'])
elif test_set in SUBSETS:
subsets |= set("country:" + v.split("-")[0] for v in SUBSETS[test_set].values())
subsets |= set(v.split("-")[1] for v in SUBSETS[test_set].values())
return sorted(list(subsets))
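As a rough usage sketch of the new helper (the concrete subset names below are assumptions for illustration; the real values depend on each test set's annotations): XML test sets yield the per-document domain attributes, while SGM test sets with SUBSETS annotations yield country codes prefixed with "country:" plus domain names.

```python
from sacrebleu.utils import get_available_subsets

# Assumes the raw test set files have already been downloaded by sacrebleu.

# XML-based test set: subsets come from the documents' "domain" attribute.
print(get_available_subsets("wmt22", "en-de"))
# e.g. ['conversation', 'ecommerce', 'news', 'social']

# SGM-based test set with SUBSETS annotations: country and domain values.
print(get_available_subsets("wmt19", "en-de"))
# e.g. ['business', 'country:EU', 'country:GB', 'country:US', 'crime', ...]
```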

def filter_subset(systems, test_sets, langpair, origlang, subset=None):
"""Filter sentences with a given origlang (or subset) according to the raw SGM files."""
if origlang is None and subset is None:
return systems
if test_sets is None or langpair is None:
raise ValueError('Filtering for --origlang or --subset needs a test (-t) and a language pair (-l).')

if subset is not None and subset.startswith('country:'):
subset = subset[8:]

re_origlang = re.compile(r'.* origlang="([^"]+)".*\n')
re_id = re.compile(r'.* docid="([^"]+)".*\n')

indices_to_keep = []

for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
if not rawfile.endswith('.sgm'):
raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}')
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
number_sentences_included = 0
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
fields = dataset._unwrap_wmt21_or_later(rawfile)
domains = fields['domain'] if 'domain' in fields else itertools.repeat(None)
for doc_origlang, doc_domain in zip(fields['origlang'], domains):
if origlang is None:
include_doc = True
else:
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
include_doc = doc_origlang == origlang
if subset is not None and (doc_domain is None or not re.search(subset, doc_domain)):
include_doc = False
indices_to_keep.append(include_doc)
elif rawfile.endswith('.sgm'):
doc_to_tags = {}
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
number_sentences_included += 1 if include_doc else 0
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
else:
raise Exception(f'--origlang and --subset supports only WMT *.xml and *.sgm files, not {rawfile!r}')
return [[sentence for sentence, keep in zip(sys, indices_to_keep) if keep] for sys in systems]
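For context, a hedged sketch of how the function is called (mirroring the caller in print_subset_results further down; the stream variables, test set, and filter values here are illustrative):

```python
# Keep only sentences from documents that were originally written in English
# and whose domain matches the (illustrative) 'news' subset. Each input
# stream is a list of sentences; the output preserves that structure, filtered.
filtered_sys, filtered_ref = filter_subset(
    [system_lines, reference_lines],
    "wmt22", "en-de",
    origlang="en", subset="news")
```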


@@ -565,8 +606,9 @@ def print_subset_results(metrics, full_system, full_refs, args):
subsets = [None]
if args.subset is not None:
subsets += [args.subset]
elif all(t in SUBSETS for t in args.test_set.split(',')):
subsets += COUNTRIES + DOMAINS
else:
subsets += get_available_subsets(args.test_set, args.langpair)

for subset in subsets:
system, *refs = filter_subset(
[full_system, *full_refs], args.test_set, args.langpair, origlang, subset)
@@ -575,9 +617,11 @@ def print_subset_results(metrics, full_system, full_refs, args):
continue

key = f'origlang={origlang}'
if subset in COUNTRIES:
key += f' country={subset}'
elif subset in DOMAINS:
if subset is None:
key += ' domain=ALL'
elif subset.startswith('country:'):
key += f' country={subset[8:]}'
else:
key += f' domain={subset}'

for metric in metrics.values():
@@ -592,4 +636,4 @@ def print_subset_results(metrics, full_system, full_refs, args):
print(f'{key}: sentences={n_system:<6} {score.name:<{max_metric_width}} = {score.score:.{w}f}')

# import at the end to avoid circular import
from .dataset import DATASETS, SUBSETS, DOMAINS, COUNTRIES # noqa: E402
from .dataset import DATASETS, SUBSETS # noqa: E402
10 changes: 5 additions & 5 deletions test.sh
@@ -74,16 +74,16 @@ declare -A EXPECTED
EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi"]=53.7432
EXPECTED["${CMD} -t wmt16,wmt17 -l en-fi --echo ref | ${CMD} -b -w 4 -t wmt16/B,wmt17/B -l en-fi --origlang=en"]=18.9054
EXPECTED["${CMD} -t wmt17 -l en-fi --echo ref | ${CMD} -b -t wmt17/B -l en-fi --detail"]="55.6
origlang=en : sentences=1502 BLEU = 21.4
origlang=fi : sentences=1500 BLEU = 100.0"
origlang=en domain=ALL : sentences=1502 BLEU = 21.4
origlang=fi domain=ALL : sentences=1500 BLEU = 100.0"
EXPECTED["${CMD} -t wmt18,wmt19 -l en-de --echo=src | ${CMD} -t wmt18,wmt19 -l en-de -b --detail"]="3.6
origlang=de : sentences=1498 BLEU = 3.6
origlang=en : sentences=3497 BLEU = 3.5
origlang=de domain=ALL : sentences=1498 BLEU = 3.6
origlang=en domain=ALL : sentences=3497 BLEU = 3.5
origlang=en domain=business : sentences=241 BLEU = 3.4
origlang=en country=EU : sentences=265 BLEU = 2.5
origlang=en country=GB : sentences=913 BLEU = 3.1
origlang=en country=OTHER : sentences=801 BLEU = 2.5
origlang=en country=US : sentences=1518 BLEU = 4.2
origlang=en domain=business : sentences=241 BLEU = 3.4
origlang=en domain=crime : sentences=570 BLEU = 3.6
origlang=en domain=entertainment : sentences=322 BLEU = 5.1
origlang=en domain=politics : sentences=959 BLEU = 3.0