Skip to content

Commit 075cdff

Browse files
authored
Support the ability to pass in None for both get_column_plot and get_column_pair_plot (#2344)
1 parent 6b47f9d commit 075cdff

File tree

5 files changed

+203
-39
lines changed

5 files changed

+203
-39
lines changed

sdv/evaluation/_utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import pandas as pd
2+
3+
4+
def _prepare_data_visualization(data, metadata, column_names, sample_size):
    """Prepare the data for a plot.

    Every requested column whose ``sdtype`` is ``datetime`` is converted with
    ``pd.to_datetime`` (using the column's ``datetime_format`` when the metadata
    defines one), and the rows are optionally subsampled. The conversion is done
    on a copy, so the ``data`` passed in is left untouched.

    Args:
        data (pd.DataFrame or None):
            The data to be prepared.
        metadata (Metadata):
            The metadata of the data. Only its ``columns`` mapping is read.
        column_names (str, list[str] or tuple[str]):
            The column names to plot. A single name may be passed on its own.
        sample_size (int or None):
            The number of samples to plot. If ``None``, use the whole dataset.
            Sampling only happens when ``sample_size`` is smaller than the
            number of rows.

    Returns:
        pd.DataFrame or None:
            The prepared data, or ``None`` if ``data`` is ``None``.
    """
    if data is None:
        return None

    # Accept tuples as well as lists; anything else is treated as a single column name.
    if isinstance(column_names, (list, tuple)):
        col_names = list(column_names)
    else:
        col_names = [column_names]

    data = data.copy()
    for column_name in col_names:
        column_metadata = metadata.columns[column_name]
        if column_metadata['sdtype'] == 'datetime':
            datetime_format = column_metadata.get('datetime_format')
            data[column_name] = pd.to_datetime(data[column_name], format=datetime_format)

    # A falsy ``sample_size`` (``None`` or 0) means "do not subsample".
    if sample_size and sample_size < len(data):
        data = data.sample(n=sample_size)

    return data

sdv/evaluation/multi_table.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ def get_column_pair_plot(
9494
"""Get a plot of the real and synthetic data for a given column pair.
9595
9696
Args:
97-
real_data (dict):
97+
real_data (dict or None):
9898
Dictionary containing the real table data.
99-
synthetic_column (dict):
99+
synthetic_column (dict or None):
100100
Dictionary containing the synthetic table data.
101101
metadata (Metadata):
102102
Metadata describing the data.

sdv/evaluation/single_table.py

+9-26
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Methods to compare the real and synthetic data for single-table."""
22

3-
import pandas as pd
43
from sdmetrics import visualization
54
from sdmetrics.reports.single_table.diagnostic_report import DiagnosticReport
65
from sdmetrics.reports.single_table.quality_report import QualityReport
76

87
from sdv.errors import VisualizationUnavailableError
8+
from sdv.evaluation._utils import _prepare_data_visualization
99
from sdv.metadata.metadata import Metadata
1010

1111

@@ -68,9 +68,9 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
6868
"""Get a plot of the real and synthetic data for a given column.
6969
7070
Args:
71-
real_data (pandas.DataFrame):
71+
real_data (pandas.DataFrame or None):
7272
The real table data.
73-
synthetic_data (pandas.DataFrame):
73+
synthetic_data (pandas.DataFrame or None):
7474
The synthetic table data.
7575
metadata (Metadata):
7676
The table metadata.
@@ -103,14 +103,8 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
103103
"'plot_type'."
104104
)
105105

106-
if sdtype == 'datetime':
107-
datetime_format = metadata.columns.get(column_name).get('datetime_format')
108-
real_data = pd.DataFrame({
109-
column_name: pd.to_datetime(real_data[column_name], format=datetime_format)
110-
})
111-
synthetic_data = pd.DataFrame({
112-
column_name: pd.to_datetime(synthetic_data[column_name], format=datetime_format)
113-
})
106+
real_data = _prepare_data_visualization(real_data, metadata, column_name, None)
107+
synthetic_data = _prepare_data_visualization(synthetic_data, metadata, column_name, None)
114108

115109
return visualization.get_column_plot(
116110
real_data, synthetic_data, column_name, plot_type=plot_type
@@ -147,8 +141,6 @@ def get_column_pair_plot(
147141
if isinstance(metadata, Metadata):
148142
metadata = metadata._convert_to_single_table()
149143

150-
real_data = real_data.copy()
151-
synthetic_data = synthetic_data.copy()
152144
if plot_type is None:
153145
plot_type = []
154146
for column_name in column_names:
@@ -169,18 +161,9 @@ def get_column_pair_plot(
169161
else:
170162
plot_type = plot_type.pop()
171163

172-
for column_name in column_names:
173-
sdtype = metadata.columns.get(column_name)['sdtype']
174-
if sdtype == 'datetime':
175-
datetime_format = metadata.columns.get(column_name).get('datetime_format')
176-
real_data[column_name] = pd.to_datetime(real_data[column_name], format=datetime_format)
177-
synthetic_data[column_name] = pd.to_datetime(
178-
synthetic_data[column_name], format=datetime_format
179-
)
180-
181-
require_subsample = sample_size and sample_size < min(len(real_data), len(synthetic_data))
182-
if require_subsample:
183-
real_data = real_data.sample(n=sample_size)
184-
synthetic_data = synthetic_data.sample(n=sample_size)
164+
real_data = _prepare_data_visualization(real_data, metadata, column_names, sample_size)
165+
synthetic_data = _prepare_data_visualization(
166+
synthetic_data, metadata, column_names, sample_size
167+
)
185168

186169
return visualization.get_column_pair_plot(real_data, synthetic_data, column_names, plot_type)

tests/unit/evaluation/test__utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from sdv.evaluation._utils import _prepare_data_visualization
5+
from sdv.metadata import SingleTableMetadata
6+
7+
8+
def test__prepare_data_visualization():
    """Check the datetime conversion and subsampling of ``_prepare_data_visualization``."""
    # Setup
    # Fixed seed so that the rows picked by the subsampling are reproducible.
    np.random.seed(0)
    columns = {
        'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
        'col2': {'sdtype': 'numerical'},
    }
    metadata = SingleTableMetadata.load_from_dict({'columns': columns})
    data = pd.DataFrame({
        'col1': ['2021-01-01', '2021-02-01', '2021-03-01'],
        'col2': [4, 5, 6],
    })

    # Run
    result = _prepare_data_visualization(data, metadata, ['col1', 'col2'], 2)

    # Assert
    expected_dates = pd.to_datetime(['2021-03-01', '2021-02-01'])
    expected_result = pd.DataFrame({'col1': expected_dates, 'col2': [6, 5]}, index=[2, 1])
    pd.testing.assert_frame_equal(result, expected_result)

0 commit comments

Comments
 (0)