1
1
"""Methods to compare the real and synthetic data for single-table."""
2
2
3
- import pandas as pd
4
3
from sdmetrics import visualization
5
4
from sdmetrics .reports .single_table .diagnostic_report import DiagnosticReport
6
5
from sdmetrics .reports .single_table .quality_report import QualityReport
7
6
8
7
from sdv .errors import VisualizationUnavailableError
8
+ from sdv .evaluation ._utils import _prepare_data_visualization
9
9
from sdv .metadata .metadata import Metadata
10
10
11
11
@@ -68,9 +68,9 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
68
68
"""Get a plot of the real and synthetic data for a given column.
69
69
70
70
Args:
71
- real_data (pandas.DataFrame):
71
+ real_data (pandas.DataFrame or None ):
72
72
The real table data.
73
- synthetic_data (pandas.DataFrame):
73
+ synthetic_data (pandas.DataFrame or None ):
74
74
The synthetic table data.
75
75
metadata (Metadata):
76
76
The table metadata.
@@ -103,14 +103,8 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
103
103
"'plot_type'."
104
104
)
105
105
106
- if sdtype == 'datetime' :
107
- datetime_format = metadata .columns .get (column_name ).get ('datetime_format' )
108
- real_data = pd .DataFrame ({
109
- column_name : pd .to_datetime (real_data [column_name ], format = datetime_format )
110
- })
111
- synthetic_data = pd .DataFrame ({
112
- column_name : pd .to_datetime (synthetic_data [column_name ], format = datetime_format )
113
- })
106
+ real_data = _prepare_data_visualization (real_data , metadata , column_name , None )
107
+ synthetic_data = _prepare_data_visualization (synthetic_data , metadata , column_name , None )
114
108
115
109
return visualization .get_column_plot (
116
110
real_data , synthetic_data , column_name , plot_type = plot_type
@@ -147,8 +141,6 @@ def get_column_pair_plot(
147
141
if isinstance (metadata , Metadata ):
148
142
metadata = metadata ._convert_to_single_table ()
149
143
150
- real_data = real_data .copy ()
151
- synthetic_data = synthetic_data .copy ()
152
144
if plot_type is None :
153
145
plot_type = []
154
146
for column_name in column_names :
@@ -169,18 +161,9 @@ def get_column_pair_plot(
169
161
else :
170
162
plot_type = plot_type .pop ()
171
163
172
- for column_name in column_names :
173
- sdtype = metadata .columns .get (column_name )['sdtype' ]
174
- if sdtype == 'datetime' :
175
- datetime_format = metadata .columns .get (column_name ).get ('datetime_format' )
176
- real_data [column_name ] = pd .to_datetime (real_data [column_name ], format = datetime_format )
177
- synthetic_data [column_name ] = pd .to_datetime (
178
- synthetic_data [column_name ], format = datetime_format
179
- )
180
-
181
- require_subsample = sample_size and sample_size < min (len (real_data ), len (synthetic_data ))
182
- if require_subsample :
183
- real_data = real_data .sample (n = sample_size )
184
- synthetic_data = synthetic_data .sample (n = sample_size )
164
+ real_data = _prepare_data_visualization (real_data , metadata , column_names , sample_size )
165
+ synthetic_data = _prepare_data_visualization (
166
+ synthetic_data , metadata , column_names , sample_size
167
+ )
185
168
186
169
return visualization .get_column_pair_plot (real_data , synthetic_data , column_names , plot_type )
0 commit comments