-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_grimm_bert.py
106 lines (90 loc) · 4.82 KB
/
test_grimm_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from argparse import ArgumentParser, ArgumentError
from io import StringIO
from unittest import TestCase, main
from unittest.mock import patch
import grimm_bert as gb
from clustering.linkage_name import LinkageName
from clustering.metric_name import MetricName
from data.corpus_name import CorpusName
class TestGrimmBertArgumentParser(TestCase):
    """ Tests for the grimm_bert argument parser and its value checks. """

    @classmethod
    def setUpClass(cls):
        """ Build the parser once; every test in this class shares it. """
        cls.parser = gb.build_argument_parser()

    def test_build_argument_parser(self):
        """ The factory should return an argparse.ArgumentParser. """
        self.assertIsInstance(self.parser, ArgumentParser)

    def test_parse_short_options(self):
        """ Should parse the positionals and all short option flags. """
        args = ['exp_name', CorpusName.TOY.value, MetricName.COSINE.value,
                LinkageName.AVERAGE.value, '-r', 'R', '-l', 'L', '-m', 'M',
                '-c', 'C', '-d', '.5', '-k', '-e']
        parsed_args = self.parser.parse_args(args)

        self.assertEqual(parsed_args.experiment_name, 'exp_name')
        self.assertEqual(parsed_args.corpus_name, CorpusName.TOY)
        self.assertEqual(parsed_args.affinity_name, MetricName.COSINE)
        self.assertEqual(parsed_args.linkage_name, LinkageName.AVERAGE)
        self.assertEqual(parsed_args.corpus_cache, 'C')
        self.assertEqual(parsed_args.model_cache, 'M')
        self.assertEqual(parsed_args.results_path, 'R')
        self.assertEqual(parsed_args.max_distance, 0.5)
        self.assertEqual(parsed_args.log, 'L')
        self.assertTrue(parsed_args.known_senses)
        self.assertTrue(parsed_args.export_html)

    def test_parse_long_options(self):
        """ Should parse the positionals and all long option flags. """
        args = ['exp_name', CorpusName.TOY.value, MetricName.COSINE.value,
                LinkageName.COMPLETE.value,
                '--results_path', 'rp', '--log', 'INFO', '--model_cache',
                'md', '--corpus_cache', 'cd', '--max_distance', '0.5',
                '--known_senses', '--export_html']
        parsed_args = self.parser.parse_args(args)

        self.assertEqual(parsed_args.experiment_name, 'exp_name')
        self.assertEqual(parsed_args.corpus_name, CorpusName.TOY)
        self.assertEqual(parsed_args.affinity_name, MetricName.COSINE)
        self.assertEqual(parsed_args.linkage_name, LinkageName.COMPLETE)
        self.assertEqual(parsed_args.corpus_cache, 'cd')
        self.assertEqual(parsed_args.model_cache, 'md')
        self.assertEqual(parsed_args.results_path, 'rp')
        self.assertEqual(parsed_args.max_distance, 0.5)
        self.assertEqual(parsed_args.log, 'INFO')
        self.assertTrue(parsed_args.known_senses)
        self.assertTrue(parsed_args.export_html)

    def test_parse_defaults(self):
        """ Omitted options should fall back to the module defaults. """
        args = ['exp_name', CorpusName.TOY.value, MetricName.COSINE.value,
                LinkageName.SINGLE.value]
        parsed_args = self.parser.parse_args(args)

        self.assertEqual(parsed_args.experiment_name, 'exp_name')
        self.assertEqual(parsed_args.corpus_name, CorpusName.TOY)
        self.assertEqual(parsed_args.affinity_name, MetricName.COSINE)
        self.assertEqual(parsed_args.linkage_name, LinkageName.SINGLE)
        self.assertEqual(parsed_args.corpus_cache,
                         gb.DEFAULT_CORPUS_CACHE_DIR)
        self.assertEqual(parsed_args.model_cache, gb.DEFAULT_MODEL_CACHE_PATH)
        self.assertEqual(parsed_args.results_path, gb.DEFAULT_RESULTS_PATH)
        self.assertIsNone(parsed_args.max_distance)
        self.assertEqual(parsed_args.log, gb.DEFAULT_LOG_LEVEL)
        self.assertFalse(parsed_args.known_senses)
        self.assertFalse(parsed_args.export_html)

    @patch('sys.stderr', new_callable=StringIO)
    def test_parse_no_max_dist(self, mock_stderr):
        """ Should exit with a usage error on empty max_distance argument. """
        # parse_args() does not let the ArgumentError escape: by default it
        # reports the message on stderr and exits, so SystemExit is the only
        # exception observable here. (The former
        # `assertRaises(ArgumentError) and assertRaises(SystemExit)`
        # evaluated to the SystemExit context manager alone.)
        with self.assertRaises(SystemExit):
            self.parser.parse_args([
                'exp_name', CorpusName.TOY.value, MetricName.COSINE.value,
                LinkageName.COMPLETE.value, '--max_distance'])

        # Checked after the with-block so it runs once SystemExit was caught.
        # assertRegex replaces the removed alias assertRegexpMatches.
        self.assertRegex(mock_stderr.getvalue(), r"expected one argument")

    def test_is_max_dist_defined_true(self):
        """ A positive maximum distance counts as defined. """
        self.assertTrue(gb.is_max_dist_defined(0.2))

    def test_is_max_dist_defined_too_small(self):
        """ Zero and negative maximum distances count as undefined. """
        self.assertFalse(gb.is_max_dist_defined(0.0))
        self.assertFalse(gb.is_max_dist_defined(-0.1))

    def test_is_max_dist_defined_not_defined(self):
        """ None counts as an undefined maximum distance. """
        self.assertFalse(gb.is_max_dist_defined(None))

    def test_is_min_silhouette_defined(self):
        """ Silhouette thresholds 0.0, 0.5 and 1.0 count as defined. """
        self.assertTrue(gb.is_min_silhouette_defined(0.0))
        self.assertTrue(gb.is_min_silhouette_defined(0.5))
        self.assertTrue(gb.is_min_silhouette_defined(1.0))

    def test_is_max_dist_defined_invalid(self):
        """ Out-of-range values count as undefined. """
        # NOTE(review): the name mentions max_dist, yet the first assertion
        # targets is_min_silhouette_defined; the second repeats a check from
        # test_is_max_dist_defined_too_small. Presumably this was meant to be
        # test_is_min_silhouette_defined_invalid - confirm before renaming.
        self.assertFalse(gb.is_min_silhouette_defined(1.1))
        self.assertFalse(gb.is_max_dist_defined(-0.1))

    def test_is_min_silhouette_defined_not_defined(self):
        """ None counts as an undefined minimum silhouette. """
        self.assertFalse(gb.is_min_silhouette_defined(None))
# Run this module's tests only when it is executed directly as a script.
if __name__ == "__main__":
    main()