Skip to content

Commit 9f5992a

Browse files
Muellersen, megalinter-bot, and lars-reimann
authored
feat: implement violin plots (#900)
Closes #867

### Summary of Changes

This pull request implements violin plots for tables and columns, adds corresponding tests, and adds a section to data_visualization.ipynb.

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
1 parent 5a0cdb3 commit 9f5992a

23 files changed

+320
-19
lines changed

docs/tutorials/data_visualization.ipynb

Lines changed: 64 additions & 6 deletions
Large diffs are not rendered by default.

src/safeds/data/tabular/plotting/_column_plotter.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def box_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
5656
"""
5757
if self._column.row_count > 0:
5858
_check_column_is_numeric(self._column, operation="create a box plot")
59-
6059
import matplotlib.pyplot as plt
6160

6261
def _set_boxplot_colors(box: dict, theme: str) -> None:
@@ -127,6 +126,73 @@ def _set_boxplot_colors(box: dict, theme: str) -> None:
127126

128127
return _figure_to_image(fig)
129128

129+
def violin_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
    """
    Create a violin plot for the values in the column. This is only possible for numeric columns.

    Null values are dropped before plotting. If no values remain, two NaN placeholders are plotted instead.

    Parameters
    ----------
    theme:
        The color theme of the plot. Default is "light".

    Returns
    -------
    plot:
        The violin plot as an image.

    Raises
    ------
    ColumnTypeError
        If the column has at least one row and is not numeric.

    Examples
    --------
    >>> from safeds.data.tabular.containers import Column
    >>> column = Column("test", [1, 2, 3])
    >>> violinplot = column.plot.violin_plot()
    """
    # The type check is skipped for columns without rows, so an empty column can still be plotted.
    if self._column.row_count > 0:
        _check_column_is_numeric(self._column, operation="create a violin plot")

    from math import nan

    import matplotlib.pyplot as plt

    is_dark = theme == "dark"

    # rcParams overrides applied on top of the selected matplotlib style.
    rc_overrides = {"grid.linewidth": 0.5}
    if is_dark:
        rc_overrides = {
            "text.color": "white",
            "axes.labelcolor": "white",
            "axes.edgecolor": "white",
            "xtick.color": "white",
            "ytick.color": "white",
            "grid.color": "gray",
            "grid.linewidth": 0.5,
        }

    with plt.style.context("dark_background" if is_dark else "default"):
        plt.rcParams.update(rc_overrides)

        figure, axes = plt.subplots()

        values = self._column._series.drop_nulls()
        if len(values) == 0:
            # Nothing left to plot: substitute NaN placeholders so violinplot still receives data.
            values = [nan, nan]
        axes.violinplot(values)

        axes.set(title=self._column.name)
        axes.yaxis.grid(visible=True)
        figure.tight_layout()

        return _figure_to_image(figure)
195+
130196
def histogram(self, *, max_bin_count: int = 10, theme: Literal["dark", "light"] = "light") -> Image:
131197
"""
132198
Create a histogram for the values in the column.

src/safeds/data/tabular/plotting/_table_plotter.py

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,102 @@ def box_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
119119
fig.delaxes(axs[number_of_rows - 1, i])
120120

121121
fig.tight_layout()
122+
return _figure_to_image(fig)
123+
124+
def violin_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
    """
    Create a violin plot for every numerical column.

    The plots are arranged in a grid with at most three plots per row. Null values are dropped from each column
    before plotting. If a column has no values left, two NaN placeholders are plotted for it instead.

    Parameters
    ----------
    theme:
        The color theme of the plot. Default is "light".

    Returns
    -------
    plot:
        The violin plot(s) as an image.

    Raises
    ------
    NonNumericColumnError
        If the table contains only non-numerical columns.

    Examples
    --------
    >>> from safeds.data.tabular.containers import Table
    >>> table = Table({"a": [1, 2], "b": [3, 42]})
    >>> image = table.plot.violin_plots()
    """
    numerical_table = self._table.remove_non_numeric_columns()
    if numerical_table.column_count == 0:
        raise NonNumericColumnError("This table contains only non-numerical columns.")

    from math import ceil, nan

    import matplotlib.pyplot as plt

    style = "dark_background" if theme == "dark" else "default"
    with plt.style.context(style):
        if theme == "dark":
            plt.rcParams.update(
                {
                    "text.color": "white",
                    "axes.labelcolor": "white",
                    "axes.edgecolor": "white",
                    "xtick.color": "white",
                    "ytick.color": "white",
                    "grid.color": "gray",
                    "grid.linewidth": 0.5,
                },
            )
        else:
            plt.rcParams.update(
                {
                    "grid.linewidth": 0.5,
                },
            )

        column_names = numerical_table.column_names
        columns = [column._series.drop_nulls() for column in numerical_table.to_columns()]

        max_width = 3  # maximum number of plots per grid row
        number_of_columns = min(len(columns), max_width)
        number_of_rows = ceil(len(columns) / number_of_columns)

        # squeeze=False always returns a 2D array of axes, so single-column and single-row grids
        # need no special indexing.
        fig, axs = plt.subplots(nrows=number_of_rows, ncols=number_of_columns, squeeze=False)

        for i, column in enumerate(columns):
            data = column.to_list()
            if len(data) == 0:
                # Same placeholder the column plotter uses: violinplot cannot handle an empty dataset.
                data = [nan, nan]

            row, col = divmod(i, number_of_columns)
            axs[row, col].violinplot(
                data,
            )
            axs[row, col].set_title(column_names[i])

        # Remove unused axes in the last row, so there won't be empty plots. A non-zero remainder
        # is only possible when there is more than one row.
        last_filled_ax_index = len(columns) % number_of_columns
        if last_filled_ax_index != 0:
            for i in range(last_filled_ax_index, number_of_columns):
                fig.delaxes(axs[number_of_rows - 1, i])

        fig.tight_layout()
        return _figure_to_image(fig)
136218

137219
def correlation_heatmap(self, *, theme: Literal["dark", "light"] = "light") -> Image:
138220
"""
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
from safeds.data.tabular.containers import Column
3+
from safeds.exceptions import ColumnTypeError
4+
from syrupy import SnapshotAssertion
5+
6+
7+
@pytest.mark.parametrize(
    "column",
    [
        pytest.param(Column("a", []), id="empty"),
        pytest.param(Column("a", [0]), id="one row"),
        pytest.param(Column("a", [0, 1]), id="multiple rows"),
    ],
)
def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
    """Compare the violin plot of a column (default theme) against the `snapshot_png_image` fixture."""
    rendered = column.plot.violin_plot()
    assert rendered == snapshot_png_image
23+
24+
25+
@pytest.mark.parametrize(
    "column",
    [
        pytest.param(Column("a", []), id="empty"),
        pytest.param(Column("a", [0]), id="one row"),
        pytest.param(Column("a", [0, 1]), id="multiple rows"),
    ],
)
def test_should_match_dark_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
    """Compare the violin plot of a column (dark theme) against the `snapshot_png_image` fixture."""
    rendered = column.plot.violin_plot(theme="dark")
    assert rendered == snapshot_png_image
41+
42+
43+
def test_should_raise_if_column_contains_non_numerical_values() -> None:
    """Plotting a column of strings must raise a ColumnTypeError."""
    string_column = Column("a", ["A", "B", "C"])

    with pytest.raises(ColumnTypeError):
        string_column.plot.violin_plot()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
from safeds.data.tabular.containers import Table
3+
from safeds.exceptions import NonNumericColumnError
4+
from syrupy import SnapshotAssertion
5+
6+
7+
@pytest.mark.parametrize(
    "table",
    [
        pytest.param(
            Table({"A": [1, 2, 3]}),
            id="one column",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
            id="four columns (some non-numeric)",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
            id="four columns (all numeric)",
        ),
    ],
)
def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
    """Compare the violin plots of a table (default theme) against the `snapshot_png_image` fixture."""
    rendered = table.plot.violin_plots()
    assert rendered == snapshot_png_image
19+
20+
21+
@pytest.mark.parametrize(
    "table",
    [
        pytest.param(
            Table({"A": [1, 2, 3]}),
            id="one column",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
            id="four columns (some non-numeric)",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
            id="four columns (all numeric)",
        ),
    ],
)
def test_should_match_dark_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
    """Compare the violin plots of a table (dark theme) against the `snapshot_png_image` fixture."""
    rendered = table.plot.violin_plots(theme="dark")
    assert rendered == snapshot_png_image
33+
34+
35+
def test_should_raise_if_column_contains_non_numerical_values() -> None:
    """A table whose columns all hold strings must raise a NonNumericColumnError with the expected message."""
    string_table = Table.from_dict({"A": ["1", "2", "3.5"], "B": ["0.2", "4", "77"]})

    # Regular expression matched against the error message.
    expected_message = (
        r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis table contains only"
        r" non-numerical columns."
    )

    with pytest.raises(NonNumericColumnError, match=expected_message):
        string_table.plot.violin_plots()
45+
46+
47+
def test_should_fail_on_empty_table() -> None:
    """Requesting violin plots for a table without columns must raise a NonNumericColumnError."""
    empty_table = Table()

    with pytest.raises(NonNumericColumnError):
        empty_table.plot.violin_plots()

0 commit comments

Comments
 (0)