Skip to content

Commit 6ef5f7d

Browse files
authored
Merge pull request #4959 from janezd/distributions-sort
[ENH] Distributions: Add sorting by category size
2 parents e812868 + 8a48cbd commit 6ef5f7d

2 files changed

Lines changed: 99 additions & 18 deletions

File tree

Orange/widgets/visualize/owdistributions.py

Lines changed: 32 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,7 @@ class Warning(OWWidget.Warning):
280280
show_probs = settings.Setting(False)
281281
stacked_columns = settings.Setting(False)
282282
cumulative_distr = settings.Setting(False)
283+
sort_by_freq = settings.Setting(False)
283284
kde_smoothing = settings.Setting(10)
284285

285286
auto_apply = settings.Setting(True)
@@ -314,11 +315,14 @@ def __init__(self):
314315
self.key_operation = None
315316
self._user_var_bins = {}
316317

317-
gui.listView(
318+
varview = gui.listView(
318319
self.controlArea, self, "var", box="Variable",
319320
model=DomainModel(valid_types=DomainModel.PRIMITIVE,
320321
separators=False),
321322
callback=self._on_var_changed)
323+
gui.checkBox(
324+
varview.box, self, "sort_by_freq", "Sort categories by frequency",
325+
callback=self._on_sort_by_freq, stateWhenDisabled=False)
322326

323327
box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
324328
slider = gui.hSlider(
@@ -466,6 +470,10 @@ def _on_show_cumulative(self):
466470
self.replot()
467471
self.apply()
468472

473+
def _on_sort_by_freq(self):
474+
self.replot()
475+
self.apply()
476+
469477
def _on_bins_changed(self):
470478
self.reset_select()
471479
self._set_bin_width_slider_label()
@@ -581,6 +589,7 @@ def _set_axis_names(self):
581589

582590
def _update_controls_state(self):
583591
assert self.is_valid # called only from replot, so assumes data is OK
592+
self.controls.sort_by_freq.setDisabled(self.var.is_continuous)
584593
self.continuous_box.setDisabled(self.var.is_discrete)
585594
self.controls.show_probs.setDisabled(self.cvar is None)
586595
self.controls.stacked_columns.setDisabled(self.cvar is None)
@@ -610,11 +619,18 @@ def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded,
610619

611620
def _disc_plot(self):
612621
var = self.var
613-
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
614-
colors = [QColor(0, 128, 255)]
615622
dist = distribution.get_distribution(self.data, self.var)
616-
for i, freq in enumerate(dist):
617-
desc = var.values[i]
623+
dist = np.array(dist) # Distribution misbehaves in further operations
624+
if self.sort_by_freq:
625+
order = np.argsort(dist)[::-1]
626+
else:
627+
order = np.arange(len(dist))
628+
629+
ordered_values = np.array(var.values)[order]
630+
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
631+
632+
colors = [QColor(0, 128, 255)]
633+
for i, freq, desc in zip(count(), dist[order], ordered_values):
618634
tooltip = \
619635
"<p style='white-space:pre;'>" \
620636
f"<b>{escape(desc)}</b>: {int(freq)} " \
@@ -625,13 +641,20 @@ def _disc_plot(self):
625641

626642
def _disc_split_plot(self):
627643
var = self.var
628-
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
644+
conts = contingency.get_contingency(self.data, self.cvar, self.var)
645+
conts = np.array(conts) # Contingency misbehaves in further operations
646+
if self.sort_by_freq:
647+
order = np.argsort(conts.sum(axis=1))[::-1]
648+
else:
649+
order = np.arange(len(conts))
650+
651+
ordered_values = np.array(var.values)[order]
652+
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
653+
629654
gcolors = [QColor(*col) for col in self.cvar.colors]
630655
gvalues = self.cvar.values
631-
conts = contingency.get_contingency(self.data, self.cvar, self.var)
632656
total = len(self.data)
633-
for i, freqs in enumerate(conts):
634-
desc = var.values[i]
657+
for i, freqs, desc in zip(count(), conts[order], ordered_values):
635658
self._add_bar(
636659
i - 0.5, 1, 0.1, freqs, gcolors,
637660
stacked=self.stacked_columns, expanded=self.show_probs,

Orange/widgets/visualize/tests/test_owdistributions.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -376,35 +376,40 @@ def test_controls_disabling(self):
376376
cont = self.iris.domain[0]
377377
disc = self.iris.domain.class_var
378378
cont_box = widget.continuous_box
379+
sort_by_freq = widget.controls.sort_by_freq
379380
show_probs = widget.controls.show_probs
380381
stacked = widget.controls.stacked_columns
381382

382383
self._set_var(cont)
383384
self._set_cvar(disc)
385+
self.assertFalse(sort_by_freq.isEnabled())
384386
self.assertTrue(cont_box.isEnabled())
385387
self.assertTrue(show_probs.isEnabled())
386388
self.assertTrue(stacked.isEnabled())
387389

388390
self._set_var(cont)
389391
self._set_cvar(None)
392+
self.assertFalse(sort_by_freq.isEnabled())
390393
self.assertTrue(cont_box.isEnabled())
391394
self.assertFalse(show_probs.isEnabled())
392395
self.assertFalse(stacked.isEnabled())
393396

394397
self._set_var(disc)
395398
self._set_cvar(None)
399+
self.assertTrue(sort_by_freq.isEnabled())
396400
self.assertFalse(cont_box.isEnabled())
397401
self.assertFalse(show_probs.isEnabled())
398402
self.assertFalse(stacked.isEnabled())
399403

400404
self._set_var(disc)
401405
self._set_cvar(disc)
406+
self.assertTrue(sort_by_freq.isEnabled())
402407
self.assertFalse(cont_box.isEnabled())
403408
self.assertTrue(show_probs.isEnabled())
404409
self.assertTrue(stacked.isEnabled())
405410

406411
if os.getenv("CI"):
407-
# Testing all combinations takes 10-15 seconds; this should take < 2s
412+
# Testing all combinations takes almost a minute; this should take < 2s
408413
# Code for fitter, stacked_columns and show_probs is independent, so
409414
# changing them simultaneously doesn't significantly degrade the tests
410415
def test_plot_types_combinations(self):
@@ -424,6 +429,7 @@ def test_plot_types_combinations(self):
424429
self._set_fitter(2 * b)
425430
self._set_check(c.stacked_columns, b)
426431
self._set_check(c.show_probs, b)
432+
self._set_check(c.sort_by_freq, b)
427433
qApp.processEvents()
428434
else:
429435
def test_plot_types_combinations(self):
@@ -433,6 +439,7 @@ def test_plot_types_combinations(self):
433439

434440
widget = self.widget
435441
c = widget.controls
442+
set_chk = self._set_check
436443
self.send_signal(widget.Inputs.data, self.iris)
437444
cont = self.iris.domain[0]
438445
disc = self.iris.domain.class_var
@@ -442,14 +449,15 @@ def test_plot_types_combinations(self):
442449
for cumulative in [False, True]:
443450
for stack in [False, True]:
444451
for show_probs in [False, True]:
445-
self._set_var(var)
446-
self._set_cvar(cvar)
447-
self._set_fitter(fitter)
448-
self._set_check(c.cumulative_distr,
449-
cumulative)
450-
self._set_check(c.stacked_columns, stack)
451-
self._set_check(c.show_probs, show_probs)
452-
qApp.processEvents()
452+
for sort_by_freq in [False, True]:
453+
self._set_var(var)
454+
self._set_cvar(cvar)
455+
self._set_fitter(fitter)
456+
set_chk(c.cumulative_distr, cumulative)
457+
set_chk(c.stacked_columns, stack)
458+
set_chk(c.show_probs, show_probs)
459+
set_chk(c.sort_by_freq, sort_by_freq)
460+
qApp.processEvents()
453461

454462
def test_selection_grouping(self):
455463
"""Widget groups consecutive selected bars"""
@@ -543,6 +551,56 @@ def test_summary(self):
543551
self.assertEqual(info._StateInfo__output_summary.brief, "")
544552
self.assertEqual(info._StateInfo__output_summary.details, no_output)
545553

554+
def test_sort_by_freq_no_split(self):
555+
data = Table("heart_disease")
556+
domain = data.domain
557+
sort_by_freq = self.widget.controls.sort_by_freq
558+
559+
self.send_signal(self.widget.Inputs.data, data)
560+
self._set_var(domain["gender"])
561+
self._set_cvar(None)
562+
563+
self._set_check(sort_by_freq, False)
564+
out = self.get_output(self.widget.Outputs.histogram_data)
565+
self.assertEqual(out[0][0], "female")
566+
self.assertEqual(out[0][1], 97)
567+
self.assertEqual(out[1][0], "male")
568+
self.assertEqual(out[1][1], 206)
569+
570+
self._set_check(sort_by_freq, True)
571+
out = self.get_output(self.widget.Outputs.histogram_data)
572+
self.assertEqual(out[0][0], "male")
573+
self.assertEqual(out[0][1], 206)
574+
self.assertEqual(out[1][0], "female")
575+
self.assertEqual(out[1][1], 97)
576+
577+
def test_sort_by_freq_split(self):
578+
data = Table("heart_disease")
579+
domain = data.domain
580+
sort_by_freq = self.widget.controls.sort_by_freq
581+
582+
self.send_signal(self.widget.Inputs.data, data)
583+
self._set_var(domain["gender"])
584+
self._set_cvar(domain["rest ECG"])
585+
586+
self._set_check(sort_by_freq, False)
587+
out = self.get_output(self.widget.Outputs.histogram_data)
588+
self.assertEqual(out[0][0], "female")
589+
self.assertEqual(out[0][1], "normal")
590+
self.assertEqual(out[0][2], 49)
591+
self.assertEqual(out[4][0], "male")
592+
self.assertEqual(out[4][1], "left vent hypertrophy")
593+
self.assertEqual(out[4][2], 103)
594+
595+
self._set_check(sort_by_freq, True)
596+
out = self.get_output(self.widget.Outputs.histogram_data)
597+
self.assertEqual(out[0][0], "male")
598+
self.assertEqual(out[0][1], "normal")
599+
self.assertEqual(out[0][2], 102)
600+
self.assertEqual(out[4][0], "female")
601+
self.assertEqual(out[4][1], "left vent hypertrophy")
602+
self.assertEqual(out[4][2], 45)
603+
546604

547605
if __name__ == "__main__":
548606
unittest.main()

0 commit comments

Comments
 (0)