Skip to content

Commit aeea336

Browse files
committed
Add sequence key and sequence index to multitable
1 parent c1acb42 commit aeea336

File tree

2 files changed

+66
-20
lines changed

2 files changed

+66
-20
lines changed

sdv/metadata/multi_table.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,26 @@ def anonymize(self):
972972

973973
return MultiTableMetadata.load_from_dict(anonymized_metadata)
974974

975+
def _get_table_info(self, table_name, show_table_details):
976+
node_info = {}
977+
table_meta = self.tables[table_name]
978+
979+
if show_table_details in ['full', 'summarized']:
980+
node_info['primary_key'] = f'Primary key: {table_meta.primary_key}'
981+
if table_meta.sequence_key:
982+
node_info['sequence_key'] = f'Sequence key: {table_meta.sequence_key}'
983+
if table_meta.sequence_index:
984+
node_info['sequence_index'] = f'Sequence index: {table_meta.sequence_index}'
985+
986+
if show_table_details == 'full':
987+
node_info['columns'] = create_columns_node(table_meta.columns)
988+
elif show_table_details == 'summarized':
989+
node_info['columns'] = create_summarized_columns_node(table_meta.columns)
990+
elif show_table_details is None:
991+
return
992+
993+
return node_info
994+
975995
def visualize(
976996
self, show_table_details='full', show_relationship_labels=True, output_filepath=None
977997
):
@@ -1017,22 +1037,9 @@ def visualize(
10171037

10181038
nodes = {}
10191039
edges = []
1020-
if show_table_details == 'full':
1021-
for table_name, table_meta in self.tables.items():
1022-
nodes[table_name] = {
1023-
'columns': create_columns_node(table_meta.columns),
1024-
'primary_key': f'Primary key: {table_meta.primary_key}',
1025-
}
1026-
1027-
elif show_table_details == 'summarized':
1028-
for table_name, table_meta in self.tables.items():
1029-
nodes[table_name] = {
1030-
'columns': create_summarized_columns_node(table_meta.columns),
1031-
'primary_key': f'Primary key: {table_meta.primary_key}',
1032-
}
10331040

1034-
elif show_table_details is None:
1035-
nodes = {table_name: None for table_name in self.tables}
1041+
for table_name in self.tables.keys():
1042+
nodes[table_name] = self._get_table_info(table_name, show_table_details)
10361043

10371044
for relationship in self.relationships:
10381045
parent = relationship.get('parent_table_name')
@@ -1053,11 +1060,18 @@ def visualize(
10531060
for table, info in nodes.items():
10541061
if show_table_details:
10551062
foreign_keys = r'\l'.join(info.get('foreign_keys', []))
1056-
keys = r'\l'.join([info['primary_key'], foreign_keys])
1057-
if foreign_keys:
1058-
label = rf'{{{table}|{info["columns"]}\l|{keys}\l}}'
1059-
else:
1060-
label = rf'{{{table}|{info["columns"]}\l|{keys}}}'
1063+
keys = r'\l'.join(
1064+
filter(
1065+
bool,
1066+
[
1067+
info.get('primary_key'),
1068+
info.get('sequence_key'),
1069+
info.get('sequence_index'),
1070+
foreign_keys,
1071+
],
1072+
)
1073+
)
1074+
label = rf'{{{table}|{info["columns"]}\l|{keys}\l}}'
10611075

10621076
else:
10631077
label = f'{table}'

tests/unit/metadata/test_multi_table.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3337,3 +3337,35 @@ def test_validate_data_without_dict(self):
33373337
# Run and Assert
33383338
with pytest.raises(InvalidMetadataError, match=error_msg):
33393339
metadata.validate_data(data)
3340+
3341+
@patch('sdv.metadata.multi_table.create_summarized_columns_node')
3342+
@patch('sdv.metadata.multi_table.create_columns_node')
3343+
def test__get_table_info(self, mock_columns_node, mock_summarized_columns_node):
3344+
"""Test that the `_get_table_info` method."""
3345+
# Setup
3346+
mock_columns_node.return_value = 'column'
3347+
mock_summarized_columns_node.return_value = 'column'
3348+
metadata = MultiTableMetadata()
3349+
table = Mock()
3350+
table.primary_key = 'primary_key'
3351+
table.sequence_key = None
3352+
table.sequence_index = None
3353+
table_all = Mock()
3354+
table_all.primary_key = 'primary_key'
3355+
table_all.sequence_index = 'sequence_index'
3356+
table_all.sequence_key = 'sequence_key'
3357+
metadata.tables = {'table': table, 'table_all': table_all}
3358+
3359+
# Run
3360+
result = metadata._get_table_info('table', show_table_details='full')
3361+
result_all = metadata._get_table_info('table_all', show_table_details='summarized')
3362+
result_no_info = metadata._get_table_info('table_all', show_table_details=None)
3363+
3364+
# Assert
3365+
assert result['primary_key'] == 'Primary key: primary_key'
3366+
assert 'sequence_index' not in result
3367+
assert 'sequence_key' not in result
3368+
assert result_all['primary_key'] == 'Primary key: primary_key'
3369+
assert result_all['sequence_key'] == 'Sequence key: sequence_key'
3370+
assert result_all['sequence_index'] == 'Sequence index: sequence_index'
3371+
assert result_no_info is None

0 commit comments

Comments
 (0)