Commit 18fc3c6

make release-tag: Merge branch 'main' into stable
2 parents 13d3e2e + 1205c7e commit 18fc3c6

50 files changed: +1789 -656 lines changed

.github/auto_assign.yml

+1 -1
@@ -1,2 +1,2 @@
 # Set to true to add assignees to pull requests
-addAssignees: true
+addAssignees: author

HISTORY.md

+37 -1
@@ -1,6 +1,42 @@
 # Release Notes
 
-### v1.15.0 - 2024-07-11
+### v1.16.0 - 2024-08-22
+
+This release enables the `HMASynthesizer` and other utility functions to work with null foreign key values! It also adds an `anonymization` method to the metadata classes. Additionally, it patches a bug that lets SDV work with more Pandas data types.
+
+### New Features
+
+* Add metadata anonymization to public SDV - Issue [#2137](https://github.com/sdv-dev/SDV/issues/2137) by @R-Palazzo
+* Switch drop_missing_values in in drop_unknown_references to support null foreign keys by default - Issue [#2076](https://github.com/sdv-dev/SDV/issues/2076) by @R-Palazzo
+* Support nullable foreign keys in HMA - Issue [#2063](https://github.com/sdv-dev/SDV/issues/2063) by @rwedge
+* Remove input error from base synthesizer class once nullable foreign keys are supported - Issue [#2057](https://github.com/sdv-dev/SDV/issues/2057) by @rwedge
+* Support null foreign keys in get_random_subset - Issue [#2056](https://github.com/sdv-dev/SDV/issues/2056) by @R-Palazzo
+* Warn the user if they are trying to save an unfit synthesizer - Issue [#1961](https://github.com/sdv-dev/SDV/issues/1961) by @fealho
+
+### Bugs Fixed
+
+* Using FixedCombinations constraint with an integer constraint column causes sampling to fail - Issue [#2183](https://github.com/sdv-dev/SDV/issues/2183) by @R-Palazzo
+* Metadata Detection Fails with new Data Type - Issue [#2182](https://github.com/sdv-dev/SDV/issues/2182) by @R-Palazzo
+* Unable visualize just the real data (or just the synthetic data) in a multi-table setting - Issue [#2160](https://github.com/sdv-dev/SDV/issues/2160) by @R-Palazzo
+* [dtypes] Numerical Formatter Fails to Learn Format of New Data Types - Issue [#2156](https://github.com/sdv-dev/SDV/issues/2156) by @R-Palazzo
+* Primary keys may not be unique for variable length regexes - Issue [#2116](https://github.com/sdv-dev/SDV/issues/2116) by @amontanez24
+* Confusing warning when using GANs that suggests that CUDA isn't being used - Issue [#2052](https://github.com/sdv-dev/SDV/issues/2052) by @fealho
+* PAR DiagnosticReport not 1.0 with float categorical columns - Issue [#1910](https://github.com/sdv-dev/SDV/issues/1910) by @lajohn4747
+* In `PARSynthesizer` I cannot pass in datetime context (`InvalidDataError` during fitting) - Issue [#1485](https://github.com/sdv-dev/SDV/issues/1485) by @lajohn4747
+
+### Internal
+
+* Enabling sdv logging causes tests to fail locally - Issue [#2162](https://github.com/sdv-dev/SDV/issues/2162) by @amontanez24
+* Separate primary key detection functionality - Issue [#2101](https://github.com/sdv-dev/SDV/issues/2101) by @amontanez24
+
+### Maintenance
+
+* [dtypes] Update the NumericalFormatter to use the `learn_rounding_digits` from RDT - Issue [#2164](https://github.com/sdv-dev/SDV/issues/2164) by @R-Palazzo
+* Mock every usage of `is_faker_function` to speed up the unit tests - Issue [#2163](https://github.com/sdv-dev/SDV/issues/2163) by @R-Palazzo
+* Review docs-related dev dependencies - Issue [#2148](https://github.com/sdv-dev/SDV/issues/2148) by @rwedge
+* Cap boto and botocore - Issue [#2123](https://github.com/sdv-dev/SDV/issues/2123) by @lajohn4747
+
+## v1.15.0 - 2024-07-11
 
 This release adds a new utils function called `get_random_sequence_subset`, that allows users to get a subset of sequential data.
 

latest_requirements.txt

+3 -3
@@ -6,6 +6,6 @@ graphviz==0.20.3
 numpy==1.26.4
 pandas==2.2.2
 platformdirs==4.2.2
-rdt==1.12.1
-sdmetrics==0.14.1
-tqdm==4.66.4
+rdt==1.12.3
+sdmetrics==0.15.1
+tqdm==4.66.5

pyproject.toml

+4 -5
@@ -21,8 +21,8 @@ license = { text = 'BSL-1.1' }
 requires-python = '>=3.8,<3.13'
 readme = 'README.md'
 dependencies = [
-    'boto3>=1.28',
-    'botocore>=1.31',
+    'boto3>=1.28,<2.0.0',
+    'botocore>=1.31,<2.0.0',
     'cloudpickle>=2.1.0',
     'graphviz>=0.13.2',
     "numpy>=1.21.0,<2.0.0;python_version<'3.10'",
@@ -35,7 +35,7 @@ dependencies = [
     'copulas>=0.11.0',
     'ctgan>=0.10.0',
     'deepecho>=0.6.0',
-    'rdt>=1.12.0',
+    'rdt>=1.12.3',
     'sdmetrics>=0.14.0',
     'platformdirs>=4.0',
     'pyyaml>=6.0.1',
@@ -75,7 +75,6 @@ dev = [
 
     # docs
     'docutils>=0.12,<1',
-    'm2r2>=0.2.5,<1',
     'nbsphinx>=0.5.0,<1',
     'sphinx_toolbox>=2.5,<4',
     'Sphinx>=3,<8',
@@ -133,7 +132,7 @@ namespaces = false
 version = {attr = 'sdv.__version__'}
 
 [tool.bumpversion]
-current_version = "1.15.0"
+current_version = "1.16.0.dev1"
 parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
 serialize = [
     '{major}.{minor}.{patch}.{release}{candidate}',
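
Note on the `parse` setting: as a quick sanity check, the sketch below applies the same regular expression with Python's `re` module to the old and new `current_version` strings. The helper name `split_version` is hypothetical and does not exist in the repository; it only illustrates how the optional `.dev1` suffix lands in the `release` and `candidate` groups.

```python
import re

# Pattern copied from the `parse` setting in pyproject.toml.
VERSION_PATTERN = re.compile(
    r'(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'
    r'(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
)


def split_version(version):
    """Hypothetical helper: return the named parts of a version string."""
    match = VERSION_PATTERN.fullmatch(version)
    if match is None:
        raise ValueError(f'Unrecognized version string: {version!r}')
    return match.groupdict()


print(split_version('1.15.0'))
print(split_version('1.16.0.dev1'))
```

For a plain release such as `1.15.0`, the optional group does not match, so `release` and `candidate` come back as `None`.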

sdv/__init__.py

+1 -1
@@ -6,7 +6,7 @@
 
 __author__ = 'DataCebo, Inc.'
 __email__ = '[email protected]'
-__version__ = '1.15.0'
+__version__ = '1.16.0.dev1'
 
 
 import sys

sdv/_utils.py

+36
@@ -10,10 +10,16 @@
 
 import pandas as pd
 from pandas.core.tools.datetimes import _guess_datetime_format_for_array
+from rdt.transformers.utils import _GENERATORS
 
 from sdv import version
 from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError
 
+try:
+    from re import _parser as sre_parse
+except ImportError:
+    import sre_parse
+
 
 def _cast_to_iterable(value):
     """Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
@@ -403,3 +409,33 @@ def generate_synthesizer_id(synthesizer):
     synth_version = version.public
     unique_id = ''.join(str(uuid.uuid4()).split('-'))
     return f'{class_name}_{synth_version}_{unique_id}'
+
+
+def _get_chars_for_option(option, params):
+    if option not in _GENERATORS:
+        raise ValueError(f'REGEX operation: {option} is not supported by SDV.')
+
+    if option == sre_parse.MAX_REPEAT:
+        new_option, new_params = params[2][0]  # The value at the second index is the nested option
+        return _get_chars_for_option(new_option, new_params)
+
+    return list(_GENERATORS[option](params, 1)[0])
+
+
+def get_possible_chars(regex, num_subpatterns=None):
+    """Get the list of possible characters a regex can create.
+
+    Args:
+        regex (str):
+            The regex to parse.
+        num_subpatterns (int):
+            The number of sub-patterns from the regex to find characters for.
+    """
+    parsed = sre_parse.parse(regex)
+    parsed = [p for p in parsed if p[0] != sre_parse.AT]
+    num_subpatterns = num_subpatterns or len(parsed)
+    possible_chars = []
+    for option, params in parsed[:num_subpatterns]:
+        possible_chars += _get_chars_for_option(option, params)
+
+    return possible_chars
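
Note on `get_possible_chars`: the committed version delegates character generation to RDT's private `_GENERATORS` table. To make the parsing step easier to follow, here is a minimal standard-library sketch of the same idea. It handles only literals, simple character classes and repeats, and the names `chars_for_token` and `possible_chars` are hypothetical stand-ins rather than SDV functions.

```python
try:
    from re import _parser as sre_parse  # Python 3.11 and later
except ImportError:
    import sre_parse  # Python 3.8 to 3.10


def chars_for_token(op, arg):
    """Return the characters a single parsed regex token can produce."""
    if op == sre_parse.LITERAL:
        return [chr(arg)]

    if op == sre_parse.IN:
        chars = []
        for item_op, item_arg in arg:
            if item_op == sre_parse.LITERAL:
                chars.append(chr(item_arg))
            elif item_op == sre_parse.RANGE:
                low, high = item_arg
                chars.extend(chr(code) for code in range(low, high + 1))
            else:
                raise ValueError(f'Unsupported character class item: {item_op}')
        return chars

    if op == sre_parse.MAX_REPEAT:
        # arg is (min, max, subpattern); look at the first nested token.
        nested_op, nested_arg = arg[2][0]
        return chars_for_token(nested_op, nested_arg)

    raise ValueError(f'Unsupported regex operation: {op}')


def possible_chars(regex, num_subpatterns=None):
    """List the characters the first ``num_subpatterns`` tokens can produce."""
    # Anchors such as ^ and $ produce no characters, so drop them first.
    tokens = [token for token in sre_parse.parse(regex) if token[0] != sre_parse.AT]
    num_subpatterns = num_subpatterns or len(tokens)
    chars = []
    for op, arg in tokens[:num_subpatterns]:
        chars += chars_for_token(op, arg)
    return chars


print(possible_chars('^[A-C]{2}x$'))
print(possible_chars('[a-b]z', num_subpatterns=1))
```

Like the committed code, the sketch filters out `AT` tokens before counting sub-patterns, so anchors do not consume one of the `num_subpatterns` slots.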

sdv/constraints/tabular.py

+5 -2
@@ -48,6 +48,7 @@
     cast_to_datetime64,
     compute_nans_column,
     get_datetime_diff,
+    get_mappable_combination,
     logit,
     matches_datetime_format,
     revert_nans_columns,
@@ -297,9 +298,10 @@ def _fit(self, table_data):
         self._combinations_to_uuids = {}
         self._uuids_to_combinations = {}
         for combination in self._combinations.itertuples(index=False, name=None):
+            mappable_combination = get_mappable_combination(combination)
             uuid_str = str(uuid.uuid4())
-            self._combinations_to_uuids[combination] = uuid_str
-            self._uuids_to_combinations[uuid_str] = combination
+            self._combinations_to_uuids[mappable_combination] = uuid_str
+            self._uuids_to_combinations[uuid_str] = mappable_combination
 
     def is_valid(self, table_data):
         """Say whether the column values are within the original combinations.
@@ -333,6 +335,7 @@ def _transform(self, table_data):
             pandas.DataFrame:
                 Transformed data.
         """
+        table_data[self._columns] = table_data[self._columns].replace({np.nan: None})
         combinations = table_data[self._columns].itertuples(index=False, name=None)
         uuids = map(self._combinations_to_uuids.get, combinations)
         table_data[self._joint_column] = list(uuids)

sdv/constraints/utils.py

+17
@@ -204,3 +204,20 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=
     diff_column = diff_column.astype(np.float64)
     diff_column[nan_mask] = np.nan
     return diff_column
+
+
+def get_mappable_combination(combination):
+    """Get a mappable combination of values.
+
+    This function replaces NaN values with None inside the tuple
+    to ensure consistent comparisons when using mapping.
+
+    Args:
+        combination (tuple):
+            A combination of values.
+
+    Returns:
+        tuple:
+            A mappable combination of values.
+    """
+    return tuple(None if pd.isna(x) else x for x in combination)
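
Note on `get_mappable_combination`: the docstring's "consistent comparisons" point can be reproduced without pandas. Two tuples that each hold their own NaN object never compare equal, so a dictionary keyed on the first tuple cannot be looked up with the second. The sketch below uses a hypothetical `to_mappable` helper that handles only float NaN via `math.isnan`, which is narrower than `pd.isna` in the committed code.

```python
import math


def to_mappable(combination):
    """Stand-in for get_mappable_combination that only handles float NaN."""
    return tuple(
        None if isinstance(value, float) and math.isnan(value) else value
        for value in combination
    )


fitted = (1, float('nan'))   # combination stored while fitting
sampled = (1, float('nan'))  # equal-looking combination seen later

# Keyed on the raw tuple, the lookup misses because NaN != NaN.
raw_mapping = {fitted: 'uuid-1'}
print(sampled in raw_mapping)

# Keyed on the normalized tuple, None compares equal to None.
safe_mapping = {to_mappable(fitted): 'uuid-1'}
print(to_mappable(sampled) in safe_mapping)
```

This is presumably why `_transform` also replaces `np.nan` with `None` before building the lookup tuples: both sides of the mapping need the same normalization.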

sdv/data_processing/data_processor.py

+13 -7
@@ -9,6 +9,7 @@
 import pandas as pd
 import rdt
 from pandas.api.types import is_float_dtype, is_integer_dtype
+from pandas.errors import IntCastingNaNError
 from rdt.transformers import AnonymizedFaker, get_default_transformers
 from rdt.transformers.pii.anonymization import get_anonymized_transformer
 
@@ -902,18 +903,23 @@ def reverse_transform(self, data, reset_keys=False):
             reversed_data[column_name] = column_data[column_data.notna()]
             try:
                 reversed_data[column_name] = reversed_data[column_name].astype(dtype)
-            except ValueError as e:
+            except (IntCastingNaNError, ValueError) as e:
+                message = (
+                    f"The real data in '{column_name}' was stored as '{dtype}' but the "
+                    'synthetic data could not be cast back to this type. If this is a '
+                    'problem, please check your input data and metadata settings.'
+                )
+                if isinstance(e, IntCastingNaNError):
+                    LOGGER.debug(message)
+                    continue
+
+                # Handle the ValueError case
                 column_metadata = self.metadata.columns.get(column_name)
                 sdtype = column_metadata.get('sdtype')
                 if sdtype not in self._DTYPE_TO_SDTYPE.values():
-                    LOGGER.info(
-                        f"The real data in '{column_name}' was stored as '{dtype}' but the "
-                        'synthetic data could not be cast back to this type. If this is a '
-                        'problem, please check your input data and metadata settings.'
-                    )
+                    LOGGER.info(message)
                     if column_name in self.formatters:
                         self.formatters.pop(column_name)
-
                 else:
                     raise ValueError(e)
 
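
Note on the new `except` clause: `pandas.errors.IntCastingNaNError` is a subclass of `ValueError`, so the `isinstance` check has to run before the generic handling or the NaN case would fall through to the metadata branch. The sketch below mirrors only that branch order, using a stand-in exception class and a hypothetical `classify_cast_failure` helper. The returned labels are illustrative, not values used by SDV.

```python
class IntCastingNaNError(ValueError):
    """Stand-in for pandas.errors.IntCastingNaNError (also a ValueError subclass)."""


def classify_cast_failure(error, sdtype_in_dtype_map):
    """Hypothetical helper mirroring the branch order in the except block."""
    if isinstance(error, IntCastingNaNError):
        # NaN could not be cast to an integer dtype: log at debug level and skip.
        return 'debug-and-continue'
    if not sdtype_in_dtype_map:
        # The sdtype is not one of the dtype-derived ones: log and drop the formatter.
        return 'info-and-drop-formatter'
    # Otherwise the cast failure is unexpected and is raised again.
    return 'reraise'


print(classify_cast_failure(IntCastingNaNError('nan'), True))
print(classify_cast_failure(ValueError('bad'), False))
print(classify_cast_failure(ValueError('bad'), True))
```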

sdv/data_processing/numerical_formatter.py

+11 -37
@@ -3,8 +3,8 @@
 import logging
 import sys
 
-import numpy as np
 import pandas as pd
+from rdt.transformers.utils import learn_rounding_digits
 
 LOGGER = logging.getLogger(__name__)
 
@@ -51,34 +51,6 @@ def __init__(
         self.enforce_min_max_values = enforce_min_max_values
         self.computer_representation = computer_representation
 
-    @staticmethod
-    def _learn_rounding_digits(data):
-        """Check if data has any decimals."""
-        name = data.name
-        data = np.array(data)
-        roundable_data = data[~(np.isinf(data) | pd.isna(data))]
-
-        # Doesn't contain numbers
-        if len(roundable_data) == 0:
-            return None
-
-        # Doesn't contain decimal digits
-        if ((roundable_data % 1) == 0).all():
-            return 0
-
-        # Try to round to fewer digits
-        if (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
-            for decimal in range(MAX_DECIMALS + 1):
-                if (roundable_data == roundable_data.round(decimal)).all():
-                    return decimal
-
-        # Can't round, not equal after MAX_DECIMALS digits of precision
-        LOGGER.info(
-            f"No rounding scheme detected for column '{name}'."
-            ' Synthetic data will not be rounded.'
-        )
-        return None
-
     def learn_format(self, column):
         """Learn the format of a column.
 
@@ -92,7 +64,7 @@ def learn_format(self, column):
         self._max_value = column.max()
 
         if self.enforce_rounding:
-            self._rounding_digits = self._learn_rounding_digits(column)
+            self._rounding_digits = learn_rounding_digits(column)
 
     def format_data(self, column):
         """Format a column according to the learned format.
@@ -105,20 +77,22 @@ def format_data(self, column):
             numpy.ndarray:
                 containing the formatted data.
         """
-        column = column.copy().to_numpy()
+        column = column.copy()
         if self.enforce_min_max_values:
             column = column.clip(self._min_value, self._max_value)
-        elif self.computer_representation != 'Float':
+        elif not self.computer_representation.startswith('Float'):
             min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation]
             column = column.clip(min_bound, max_bound)
 
-        is_integer = np.dtype(self._dtype).kind == 'i'
+        is_integer = pd.api.types.is_integer_dtype(self._dtype)
+        np_integer_with_nans = (
+            not pd.api.types.is_extension_array_dtype(self._dtype)
+            and is_integer
+            and pd.isna(column).any()
+        )
         if self.enforce_rounding and self._rounding_digits is not None:
             column = column.round(self._rounding_digits)
         elif is_integer:
             column = column.round(0)
 
-        if pd.isna(column).any() and is_integer:
-            return column
-
-        return column.astype(self._dtype)
+        return column.astype(self._dtype if not np_integer_with_nans else 'float64')
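
Note on the `startswith('Float')` change in `format_data`: the old check compared against the exact string `'Float'`, so any other value beginning with `Float` would have been sent to the `INTEGER_BOUNDS` lookup. The sketch below shows the difference under two assumptions that are not confirmed by this diff: that `INTEGER_BOUNDS` is keyed by names such as `'Int8'` and `'UInt8'`, and that a representation such as `'Float32'` can occur. The helper `clip_bounds` is hypothetical.

```python
# Assumed shape of the bounds table; only two entries are shown.
INTEGER_BOUNDS = {
    'Int8': (-(2**7), 2**7 - 1),
    'UInt8': (0, 2**8 - 1),
}


def clip_bounds(computer_representation):
    """Return (min, max) clip bounds, or None for floating-point representations."""
    if computer_representation.startswith('Float'):
        return None
    return INTEGER_BOUNDS[computer_representation]


for representation in ('Float', 'Float32', 'Int8', 'UInt8'):
    print(representation, clip_bounds(representation))
```

With an exact comparison, `'Float32'` would reach the dictionary lookup and raise `KeyError` in this sketch. The prefix check treats it the same as `'Float'`.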

sdv/evaluation/multi_table.py

+4 -4
@@ -77,8 +77,8 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
             1D marginal distribution plot (i.e. a histogram) of the columns.
     """
     metadata = metadata.tables[table_name]
-    real_data = real_data[table_name]
-    synthetic_data = synthetic_data[table_name]
+    real_data = real_data[table_name] if real_data else None
+    synthetic_data = synthetic_data[table_name] if synthetic_data else None
     return single_table_visualization.get_column_plot(
         real_data,
         synthetic_data,
@@ -118,8 +118,8 @@ def get_column_pair_plot(
             2D bivariate distribution plot (i.e. a scatterplot) of the columns.
     """
     metadata = metadata.tables[table_name]
-    real_data = real_data[table_name]
-    synthetic_data = synthetic_data[table_name]
+    real_data = real_data[table_name] if real_data else None
+    synthetic_data = synthetic_data[table_name] if synthetic_data else None
     return single_table_visualization.get_column_pair_plot(
         real_data, synthetic_data, metadata, column_names, sample_size, plot_type
     )
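
Note on the `if real_data else None` guard: the multi-table plotting functions receive dictionaries that map table names to tables, and the guard lets a caller pass `None` for either side. The sketch below isolates that pattern with plain lists standing in for DataFrames. The helper name `select_table` is hypothetical.

```python
def select_table(data, table_name):
    """Hypothetical helper: pick one table, or None when no data was passed."""
    return data[table_name] if data else None


real = {'users': [1, 2, 3], 'sessions': [10, 20]}

print(select_table(real, 'users'))   # the real table is selected
print(select_table(None, 'users'))   # no synthetic data was supplied
print(select_table({}, 'users'))     # an empty dict is also treated as missing
```

The truthiness test treats an empty dictionary the same as `None`. A stricter variant would test `data is not None`, at the cost of raising `KeyError` for an empty dictionary.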

sdv/logging/logger.py

+1 -1
@@ -33,7 +33,7 @@ def __init__(self, filename=None):
 
     def format(self, record):  # noqa: A003
         """Format the record and write to CSV."""
-        row = record.msg
+        row = record.msg.copy()
         row['LEVEL'] = record.levelname
         self.writer.writerow(row)
         data = self.output.getvalue()
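
Note on `record.msg.copy()`: when the log message is a dictionary, adding `LEVEL` to `record.msg` directly also mutates the dictionary the caller passed in. The sketch below reproduces that effect with a simplified stand-in formatter. The `copy_msg` switch, the two-column layout and `make_record` are illustrative assumptions, not part of `sdv/logging/logger.py`.

```python
import csv
import io
import logging


class CSVFormatter(logging.Formatter):
    """Simplified stand-in for the CSV formatter touched by this commit."""

    def __init__(self, copy_msg=True):
        super().__init__()
        self.copy_msg = copy_msg
        self.output = io.StringIO()
        self.writer = csv.DictWriter(self.output, fieldnames=['EVENT', 'LEVEL'])

    def format(self, record):
        # Copying keeps the caller's dictionary free of the extra LEVEL key.
        row = record.msg.copy() if self.copy_msg else record.msg
        row['LEVEL'] = record.levelname
        self.writer.writerow(row)
        data = self.output.getvalue()
        self.output.seek(0)
        self.output.truncate(0)
        return data.strip()


def make_record(msg):
    """Build a log record whose message is the given dictionary."""
    return logging.LogRecord('demo', logging.INFO, 'demo.py', 0, msg, None, None)


shared = {'EVENT': 'fit'}
print(CSVFormatter(copy_msg=True).format(make_record(shared)), shared)
print(CSVFormatter(copy_msg=False).format(make_record(shared)), shared)
```

With the copy, `shared` is unchanged after formatting. Without it, the `LEVEL` key leaks into the caller's dictionary, which can affect any other handler or test that reuses the same message object.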
