Skip to content

Commit

Permalink
Add parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Feb 6, 2025
1 parent 363d8bd commit 4be1c4c
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 39 deletions.
46 changes: 42 additions & 4 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,18 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
instance._set_metadata_dict(metadata_dict, single_table_name)
return instance

@staticmethod
def _validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys):
if not isinstance(infer_sdtypes, bool):
raise ValueError("'infer_sdtypes' must be a boolean value.")

if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
)

@classmethod
def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'):
    """Detect the metadata for all tables in a dictionary of dataframes.

    This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``.

    Args:
        data (dict):
            Dictionary of table names to dataframes.
        infer_sdtypes (bool):
            A boolean describing whether to infer the sdtypes of each column.
            If True it infers the sdtypes based on the data.
            If False it does not infer the sdtypes and all columns are marked as unknown.
            Defaults to True.
        infer_keys (str):
            A string describing whether to infer the primary and/or foreign keys. Options are:
                - 'primary_and_foreign': Infer the primary keys in each table,
                  and the foreign keys in other tables that refer to them
                - 'primary_only': Infer only the primary keys of each table
                - None: Do not infer any keys
            Defaults to 'primary_and_foreign'.

    Returns:
        Metadata:
            A new metadata object with the sdtypes detected from the data.

    Raises:
        ValueError:
            If ``data`` is empty or contains anything other than pandas DataFrames,
            or if ``infer_sdtypes`` / ``infer_keys`` hold an unsupported value.
    """
    if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
        raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')

    # Use the validator that this class actually defines; it is the same helper
    # that ``detect_from_dataframe`` calls.
    cls._validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys)

    metadata = Metadata()
    for table_name, dataframe in data.items():
        metadata.detect_table_from_dataframe(table_name, dataframe, infer_sdtypes, infer_keys)

    # Relationships are built from foreign keys, so they are only detected when the
    # caller asked for both primary and foreign keys.
    if infer_keys == 'primary_and_foreign':
        metadata._detect_relationships(data)

    return metadata

@classmethod
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME, infer_sdtypes=True, infer_keys='primary_and_foreign'):
"""Detect the metadata for a DataFrame.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
Expand All @@ -96,13 +121,26 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
Args:
data (pandas.DataFrame):
The dataframe to detect the metadata from.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_and_foreign': Infer the primary keys in each table,
and the foreign keys in other tables that refer to them
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_and_foreign'.
Returns:
Metadata:
A new metadata object with the sdtypes detected from the data.
"""
if not isinstance(data, pd.DataFrame):
raise ValueError('The provided data must be a pandas DataFrame object.')
cls._validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys)

metadata = Metadata()
metadata.detect_table_from_dataframe(table_name, data)
Expand Down
86 changes: 51 additions & 35 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,53 +595,68 @@ def _detect_primary_key(self, data):

return None

def _detect_columns(self, data, table_name=None):
def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=True):
"""Detect the columns' sdtypes from the data.
Args:
data (pandas.DataFrame):
The data to be analyzed.
table_name (str):
The name of the table to be analyzed. Defaults to ``None``.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_and_foreign': Infer the primary keys in each table
- 'primary_only': Same as 'primary_and_foreign', infer only the primary keys
- None: Do not infer any keys
Defaults to 'primary_and_foreign'.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
for field in data:
try:
column_data = data[field]
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

sdtype = self._detect_pii_column(field)
if sdtype is None:
if dtype in self._DTYPES_TO_SDTYPES:
sdtype = self._DTYPES_TO_SDTYPES[dtype]
elif dtype in ['i', 'f', 'u']:
sdtype = self._determine_sdtype_for_numbers(column_data)

elif dtype == 'O':
sdtype = self._determine_sdtype_for_objects(column_data)
if infer_sdtypes:
try:
column_data = data[field]
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

sdtype = self._detect_pii_column(field)
if sdtype is None:
table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
" 'bool'."
)
raise InvalidMetadataError(error_message)

except Exception as e:
error_type = type(e).__name__
if error_type == 'InvalidMetadataError':
raise e

table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
f'data format.\n {error_type}: {e}'
)
raise InvalidMetadataError(error_message) from e
if dtype in self._DTYPES_TO_SDTYPES:
sdtype = self._DTYPES_TO_SDTYPES[dtype]
elif dtype in ['i', 'f', 'u']:
sdtype = self._determine_sdtype_for_numbers(column_data)

elif dtype == 'O':
sdtype = self._determine_sdtype_for_objects(column_data)

if sdtype is None:
table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
" 'bool'."
)
raise InvalidMetadataError(error_message)

except Exception as e:
error_type = type(e).__name__
if error_type == 'InvalidMetadataError':
raise e

table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
f'data format.\n {error_type}: {e}'
)
raise InvalidMetadataError(error_message) from e

else:
sdtype = 'unknown'

column_dict = {'sdtype': sdtype}
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
Expand All @@ -655,7 +670,8 @@ def _detect_columns(self, data, table_name=None):

self.columns[field] = deepcopy(column_dict)

self.primary_key = self._detect_primary_key(data)
if infer_keys:
self.primary_key = self._detect_primary_key(data)
self._updated = True
data.columns = old_columns

Expand Down

0 comments on commit 4be1c4c

Please sign in to comment.