Skip to content

Commit 4be1c4c

Browse files
committed
Add parameters
1 parent 363d8bd commit 4be1c4c

File tree

2 files changed

+93
-39
lines changed

2 files changed

+93
-39
lines changed

sdv/metadata/metadata.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,18 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
6161
instance._set_metadata_dict(metadata_dict, single_table_name)
6262
return instance
6363

64+
@staticmethod
65+
def _validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys):
66+
if not isinstance(infer_sdtypes, bool):
67+
raise ValueError("'infer_sdtypes' must be a boolean value.")
68+
69+
if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
70+
raise ValueError(
71+
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
72+
)
73+
6474
@classmethod
65-
def detect_from_dataframes(cls, data):
75+
def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'):
6676
"""Detect the metadata for all tables in a dictionary of dataframes.
6777
6878
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``.
@@ -71,23 +81,38 @@ def detect_from_dataframes(cls, data):
7181
Args:
7282
data (dict):
7383
Dictionary of table names to dataframes.
84+
infer_sdtypes (bool):
85+
A boolean describing whether to infer the sdtypes of each column.
86+
If True it infers the sdtypes based on the data.
87+
If False it does not infer the sdtypes and all columns are marked as unknown.
88+
Defaults to True.
89+
infer_keys (str):
90+
A string describing whether to infer the primary and/or foreign keys. Options are:
91+
- 'primary_and_foreign': Infer the primary keys in each table,
92+
and the foreign keys in other tables that refer to them
93+
- 'primary_only': Infer only the primary keys of each table
94+
- None: Do not infer any keys
95+
Defaults to 'primary_and_foreign'.
7496
7597
Returns:
7698
Metadata:
7799
A new metadata object with the sdtypes detected from the data.
78100
"""
79101
if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
80102
raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')
103+
cls._validate_detect_from_dataframes(infer_sdtypes, infer_keys)
81104

82105
metadata = Metadata()
83106
for table_name, dataframe in data.items():
84-
metadata.detect_table_from_dataframe(table_name, dataframe)
107+
metadata.detect_table_from_dataframe(table_name, dataframe, infer_sdtypes, infer_keys)
108+
109+
if infer_keys == 'primary_and_foreign':
110+
metadata._detect_relationships(data)
85111

86-
metadata._detect_relationships(data)
87112
return metadata
88113

89114
@classmethod
90-
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
115+
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME, infer_sdtypes=True, infer_keys='primary_and_foreign'):
91116
"""Detect the metadata for a DataFrame.
92117
93118
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
@@ -96,13 +121,26 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
96121
Args:
97122
data (pandas.DataFrame):
98123
Dictionary of table names to dataframes.
124+
infer_sdtypes (bool):
125+
A boolean describing whether to infer the sdtypes of each column.
126+
If True it infers the sdtypes based on the data.
127+
If False it does not infer the sdtypes and all columns are marked as unknown.
128+
Defaults to True.
129+
infer_keys (str):
130+
A string describing whether to infer the primary and/or foreign keys. Options are:
131+
- 'primary_and_foreign': Infer the primary keys in each table,
132+
and the foreign keys in other tables that refer to them
133+
- 'primary_only': Infer only the primary keys of each table
134+
- None: Do not infer any keys
135+
Defaults to 'primary_and_foreign'.
99136
100137
Returns:
101138
Metadata:
102139
A new metadata object with the sdtypes detected from the data.
103140
"""
104141
if not isinstance(data, pd.DataFrame):
105142
raise ValueError('The provided data must be a pandas DataFrame object.')
143+
cls._validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys)
106144

107145
metadata = Metadata()
108146
metadata.detect_table_from_dataframe(table_name, data)

sdv/metadata/single_table.py

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -595,53 +595,68 @@ def _detect_primary_key(self, data):
595595

596596
return None
597597

598-
def _detect_columns(self, data, table_name=None):
598+
def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=True):
599599
"""Detect the columns' sdtypes from the data.
600600
601601
Args:
602602
data (pandas.DataFrame):
603603
The data to be analyzed.
604604
table_name (str):
605605
The name of the table to be analyzed. Defaults to ``None``.
606+
infer_sdtypes (bool):
607+
A boolean describing whether to infer the sdtypes of each column.
608+
If True it infers the sdtypes based on the data.
609+
If False it does not infer the sdtypes and all columns are marked as unknown.
610+
Defaults to True.
611+
infer_keys (str):
612+
A string describing whether to infer the primary and/or foreign keys. Options are:
613+
- 'primary_and_foreign': Infer the primary keys in each table
614+
- 'primary_only': Same as 'primary_and_foreign', infer only the primary keys
615+
- None: Do not infer any keys
616+
Defaults to 'primary_and_foreign'.
606617
"""
607618
old_columns = data.columns
608619
data.columns = data.columns.astype(str)
609620
for field in data:
610-
try:
611-
column_data = data[field]
612-
clean_data = column_data.dropna()
613-
dtype = clean_data.infer_objects().dtype.kind
614-
615-
sdtype = self._detect_pii_column(field)
616-
if sdtype is None:
617-
if dtype in self._DTYPES_TO_SDTYPES:
618-
sdtype = self._DTYPES_TO_SDTYPES[dtype]
619-
elif dtype in ['i', 'f', 'u']:
620-
sdtype = self._determine_sdtype_for_numbers(column_data)
621-
622-
elif dtype == 'O':
623-
sdtype = self._determine_sdtype_for_objects(column_data)
621+
if infer_sdtypes:
622+
try:
623+
column_data = data[field]
624+
clean_data = column_data.dropna()
625+
dtype = clean_data.infer_objects().dtype.kind
624626

627+
sdtype = self._detect_pii_column(field)
625628
if sdtype is None:
626-
table_str = f"table '{table_name}' " if table_name else ''
627-
error_message = (
628-
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
629-
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
630-
" 'bool'."
631-
)
632-
raise InvalidMetadataError(error_message)
633-
634-
except Exception as e:
635-
error_type = type(e).__name__
636-
if error_type == 'InvalidMetadataError':
637-
raise e
638-
639-
table_str = f"table '{table_name}' " if table_name else ''
640-
error_message = (
641-
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
642-
f'data format.\n {error_type}: {e}'
643-
)
644-
raise InvalidMetadataError(error_message) from e
629+
if dtype in self._DTYPES_TO_SDTYPES:
630+
sdtype = self._DTYPES_TO_SDTYPES[dtype]
631+
elif dtype in ['i', 'f', 'u']:
632+
sdtype = self._determine_sdtype_for_numbers(column_data)
633+
634+
elif dtype == 'O':
635+
sdtype = self._determine_sdtype_for_objects(column_data)
636+
637+
if sdtype is None:
638+
table_str = f"table '{table_name}' " if table_name else ''
639+
error_message = (
640+
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
641+
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
642+
" 'bool'."
643+
)
644+
raise InvalidMetadataError(error_message)
645+
646+
except Exception as e:
647+
error_type = type(e).__name__
648+
if error_type == 'InvalidMetadataError':
649+
raise e
650+
651+
table_str = f"table '{table_name}' " if table_name else ''
652+
error_message = (
653+
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
654+
f'data format.\n {error_type}: {e}'
655+
)
656+
raise InvalidMetadataError(error_message) from e
657+
658+
else:
659+
sdtype = 'unknown'
645660

646661
column_dict = {'sdtype': sdtype}
647662
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
@@ -655,7 +670,8 @@ def _detect_columns(self, data, table_name=None):
655670

656671
self.columns[field] = deepcopy(column_dict)
657672

658-
self.primary_key = self._detect_primary_key(data)
673+
if infer_keys:
674+
self.primary_key = self._detect_primary_key(data)
659675
self._updated = True
660676
data.columns = old_columns
661677

0 commit comments

Comments
 (0)