Skip to content

Commit 2644ffc

Browse files
committed
Test and Score: Sort numerically, not alphabetically
1 parent 6e7e534 commit 2644ffc

4 files changed

Lines changed: 77 additions & 10 deletions

File tree

Orange/widgets/evaluate/owtestlearners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def update_stats_model(self):
533533
for stat, scorer in zip(stats, self.scorers):
534534
item = QStandardItem()
535535
if stat.success:
536-
item.setText("{:.3f}".format(stat.value[0]))
536+
item.setData(float(stat.value[0]), Qt.DisplayRole)
537537
else:
538538
item.setToolTip(str(stat.exception))
539539
if scorer.name in self.score_table.shown_scores:

Orange/widgets/evaluate/tests/test_owtestlearners.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,10 @@ def __call__(self, data):
277277
# Ensure that the click on header caused an ascending sort
278278
# Ascending sort means that wrong model should be listed first
279279
self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder)
280-
self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner")
280+
self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner")
281281

282282
self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000)
283-
self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner")
283+
self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner")
284284

285285
self.widget.hide()
286286

@@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self):
365365
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
366366
)
367367

368-
self.assertTupleEqual(self._test_scores(
369-
table_train, table_test, LogisticRegressionLearner(),
370-
OWTestLearners.TestOnTest, None),
371-
(0.667, 0.8, 0.8, 0.867, 0.8))
368+
np.testing.assert_almost_equal(
369+
self._test_scores(table_train, table_test,
370+
LogisticRegressionLearner(),
371+
OWTestLearners.TestOnTest, None),
372+
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))
372373

373374
def test_scores_cross_validation(self):
374375
"""

Orange/widgets/evaluate/tests/test_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import collections
55

66
from AnyQt.QtWidgets import QMenu
7-
from AnyQt.QtCore import QPoint
7+
from AnyQt.QtGui import QStandardItem
8+
from AnyQt.QtCore import QPoint, Qt
89

910
from Orange.widgets.evaluate.utils import ScoreTable
1011
from Orange.widgets.tests.base import GuiTest
@@ -70,5 +71,48 @@ def test_update_shown_columns(self):
7071
not header.isSectionHidden(i),
7172
msg="error in section {}({})".format(i, name))
7273

74+
def test_sorting(self):
75+
def order(n=5):
76+
return "".join(model.index(i, 0).data() for i in range(n))
77+
78+
score_table = ScoreTable(None)
79+
80+
data = [
81+
["D", 11.0, 15.3],
82+
["C", 5.0, -15.4],
83+
["b", 20.0, None],
84+
["A", None, None],
85+
["E", "", 0.0]
86+
]
87+
for data_row in data:
88+
row = []
89+
for x in data_row:
90+
item = QStandardItem()
91+
if x is not None:
92+
item.setData(x, Qt.DisplayRole)
93+
row.append(item)
94+
score_table.model.appendRow(row)
95+
96+
model = score_table.view.model()
97+
98+
model.sort(0, Qt.AscendingOrder)
99+
self.assertEqual(order(), "AbCDE")
100+
101+
model.sort(0, Qt.DescendingOrder)
102+
self.assertEqual(order(), "EDCbA")
103+
104+
model.sort(1, Qt.AscendingOrder)
105+
self.assertEqual(order(3), "CDb")
106+
107+
model.sort(1, Qt.DescendingOrder)
108+
self.assertEqual(order(3), "bDC")
109+
110+
model.sort(2, Qt.AscendingOrder)
111+
self.assertEqual(order(3), "CED")
112+
113+
model.sort(2, Qt.DescendingOrder)
114+
self.assertEqual(order(3), "DEC")
115+
116+
73117
if __name__ == "__main__":
74118
unittest.main()

Orange/widgets/evaluate/utils.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu
88
from AnyQt.QtGui import QStandardItemModel, QStandardItem
9-
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal
9+
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \
10+
QSortFilterProxyModel
1011
from sklearn.exceptions import UndefinedMetricWarning
1112

1213
from Orange.data import Variable, DiscreteVariable, ContinuousVariable
@@ -98,6 +99,19 @@ def thunked():
9899
return thunked
99100

100101

102+
class ScoreModel(QSortFilterProxyModel):
103+
def lessThan(self, left, right):
104+
left = left.data()
105+
right = right.data()
106+
if type(left) is not type(right) or left is None or right is None:
107+
# put the one which is not a number (= an error) at the bottom
108+
return isinstance(left, float) == (
109+
self.sortOrder() == Qt.AscendingOrder)
110+
if isinstance(left, str):
111+
return left.upper() < right.upper()
112+
return left < right
113+
114+
101115
class ScoreTable(OWComponent, QObject):
102116
shown_scores = \
103117
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
@@ -109,6 +123,12 @@ def sizeHint(self, *args):
109123
size = super().sizeHint(*args)
110124
return QSize(size.width(), size.height() + 6)
111125

126+
def displayText(self, value, locale):
127+
if isinstance(value, float):
128+
return f"{value:.3f}"
129+
else:
130+
return super().displayText(value, locale)
131+
112132
def __init__(self, master):
113133
QObject.__init__(self)
114134
OWComponent.__init__(self, master)
@@ -125,7 +145,9 @@ def __init__(self, master):
125145

126146
self.model = QStandardItemModel(master)
127147
self.model.setHorizontalHeaderLabels(["Method"])
128-
self.view.setModel(self.model)
148+
self.sorted_model = ScoreModel()
149+
self.sorted_model.setSourceModel(self.model)
150+
self.view.setModel(self.sorted_model)
129151
self.view.setItemDelegate(self.ItemDelegate())
130152

131153
def _column_names(self):

0 commit comments

Comments
 (0)