Skip to content

Commit

Permalink
Refactoring: similarity_runner rework
Browse files — browse the repository at this point in the history
  • Loading branch information
OlivieFranklova committed Nov 18, 2024
1 parent 602d29d commit 4b08f9d
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -575,5 +575,5 @@ def _compare(self, metadata1: Metadata, metadata2: Metadata) -> SimilarityOutput
else:
result += dist * dist * ratio * weight
if nan == len(distances):
return 1
return SimilarityOutput(distance = 1)
return SimilarityOutput(distance = np.sqrt(result))
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class TypeMetadataCreator(MetadataCreator):
# generated by computer/person
# distribution of numerical data
# most common triplets, ... for text data
def __init__(self, input_: MetadataCreatorInput):
def __init__(self):
"""
Constructor of TypeMetadataCreator
:param dataframe: DataFrame from which we will create metadata
Expand All @@ -59,7 +59,7 @@ def __init__(self, input_: MetadataCreatorInput):
column_incomplete - list of booleans, for each column list will contain
True for incomplete data and False otherwise
"""
super().__init__(input_)
super().__init__()
self.model: Optional[SentenceTransformer] = SentenceTransformer("bert-base-nli-mean-tokens", tokenizer_kwargs={"clean_up_tokenization_spaces": True})

def __normalize(self, num1: int, num2: int) -> tuple[int, int]:
Expand Down
27 changes: 18 additions & 9 deletions similarity_framework/src/interfaces/metadata/MetadataCreator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import pandas as pd

from similarity_framework.src.models.metadata import Metadata, MetadataCreatorInput

Expand All @@ -11,32 +14,38 @@ class __FunctionsParams:
args: tuple
kwargs: dict

def __hash__(self):
return hash(self.func) + hash(self.args) + hash(str(self.kwargs))

@staticmethod
def buildermethod(func):
def inner1(*args, **kwargs):
if not args[0].create:
args[0].functions_to_run.append(MetadataCreator.__FunctionsParams(func=func, args=args, kwargs=kwargs))
args[0].__dict__['_MetadataCreator__functions_to_run'].add(MetadataCreator.__FunctionsParams(func=func, args=args, kwargs=kwargs))
return args[0]
else:
func(*args, **kwargs)
return args[0]
return inner1

def __init__(self, input_: MetadataCreatorInput):
self.dataframe = input_.dataframe
self.functions_to_run = list()
def __init__(self):
self.dataframe: Optional[pd.DataFrame] = None
self.__functions_to_run = set()
self.create = False
self.metadata = Metadata()
self.metadata: Optional[Metadata] = None

@abstractmethod
def _get_metadata_impl(self):
pass

def get_metadata(self) -> Metadata:
def get_metadata(self, input_: MetadataCreatorInput) -> Metadata:
self.dataframe = input_.dataframe
self.metadata = Metadata()

self._get_metadata_impl()
self.create = True
for fun in self.functions_to_run:
for fun in self.__functions_to_run:
fun.func(*fun.args, **fun.kwargs)
self.functions_to_run = set()
self.functions_to_run = list()
self.__functions_to_run = set()
self.create = False
return self.metadata
4 changes: 2 additions & 2 deletions similarity_runner/src/interfaces/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def run(self):

metadata_input: list[MetadataCreatorInput] = connector.get_data(connector_settings)
metadata = []
# TODO: call specific methods on connector based on analysis settings
TypeMetadataCreator().compute_advanced_structural_types()
for i in metadata_input:
metadata.append(
# TODO: call specific methods on connector based on analysis settings
TypeMetadataCreator(i)
.compute_basic_types()
.get_metadata()
)
Expand Down
44 changes: 22 additions & 22 deletions tests/similarity_framework/test_similarity_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,21 @@ def setUp(self):
self.data_second_half.index = self.data_second_half.index - int(len(self.data) / 2)
self.data_diff_type = self.data.copy() # todo fill

self.metadata_creator = TypeMetadataCreator(
MetadataCreatorInput(dataframe=self.data)
).compute_advanced_structural_types().compute_column_kind()
self.metadata1 = self.metadata_creator.get_metadata()
self.metadata_creator = TypeMetadataCreator().compute_advanced_structural_types().compute_column_kind()
self.metadata1 = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))

metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_diff_column_names)).
metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_diff_column_names = metadata_creator.get_metadata()
metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_first_half)).
self.metadata_diff_column_names = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_diff_column_names))
metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_first_half = metadata_creator.get_metadata()
metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_second_half)).
self.metadata_first_half = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_first_half))
metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_second_half = metadata_creator.get_metadata()
self.metadata_second_half = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_second_half))

def test_size_compare(self):
self.compartor.add_comparator_type(SizeHandler())
Expand Down Expand Up @@ -112,7 +110,7 @@ def test_exact_names_compare(self):

def test_embeddings_names_compare(self):
self.metadata_creator.compute_column_names_embeddings()
self.metadata1 = self.metadata_creator.get_metadata()
self.metadata1 = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))
self.compartor.add_comparator_type(ColumnNamesEmbeddingsHandler())

self.assertEqual(self.compartor.compare(self.metadata1, self.metadata1).distance, 0)
Expand Down Expand Up @@ -170,23 +168,25 @@ def setUp(self):
self.data_second_half.index = self.data_second_half.index - int(len(self.data) / 2)
self.data_diff_type = self.data.copy() # todo fill

self.metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data)).
self.metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata1 = self.metadata_creator.get_metadata()
self.metadata1 = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))

metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_diff_column_names)).
metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_diff_column_names = metadata_creator.get_metadata()
metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_first_half)).
self.metadata_diff_column_names = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_diff_column_names))

metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_first_half = metadata_creator.get_metadata()
metadata_creator = (TypeMetadataCreator(MetadataCreatorInput(dataframe=self.data_second_half)).
self.metadata_first_half = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_first_half))

metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata_second_half = metadata_creator.get_metadata()
self.metadata_second_half = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data_second_half))

def test_size_compare(self):
self.compartor.add_comparator_type(SizeHandlerByColumn())
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_exact_names_compare(self):

def test_embeddings_names_compare(self):
self.metadata_creator.compute_column_names_embeddings()
self.metadata1 = self.metadata_creator.get_metadata()
self.metadata1 = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))
self.compartor.add_comparator_type(ColumnNamesEmbeddingsHandlerByColumn())

self.assertEqual(self.compartor.compare(self.metadata1, self.metadata1).distance, 0)
Expand Down Expand Up @@ -265,8 +265,8 @@ def test_kind_CONSTANT_compare(self):
self.assertEqual(self.compartor.compare(self.metadata_first_half, self.metadata_second_half).distance, 0)

def test_embedding_compare(self):
self.metadata_creator.create_column_embeddings()
self.metadata1 = self.metadata_creator.get_metadata()
self.metadata_creator.create_column_embeddings().compute_advanced_structural_types().compute_column_kind()
self.metadata1 = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))
self.compartor.add_comparator_type(ColumnEmbeddingsHandler())

self.assertEqual(self.compartor.compare(self.metadata1, self.metadata1).distance, 0)
Expand Down
13 changes: 6 additions & 7 deletions tests/similarity_framework/test_similarity_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ def setUp(self):

self.file = os.path.join(THIS_DIR, '../data_validation/edge_cases.csv')
self.data = pd.read_csv(self.file)
self.metadata_creator = (TypeMetadataCreator(
MetadataCreatorInput(dataframe=self.data)
).
self.metadata_creator = (TypeMetadataCreator().
compute_advanced_structural_types().
compute_column_kind())
self.metadata = self.metadata_creator.get_metadata()
self.metadata = self.metadata_creator.get_metadata(MetadataCreatorInput(dataframe=self.data))

def test_get_column(self):
column_names = self.metadata.get_numerical_columns_names()
Expand Down Expand Up @@ -69,8 +67,8 @@ def test_default_fill(self):
'float_with_minus': ['-2.1', '-3.0', '5.0', '2.0']}
str_data = pd.DataFrame(data)
str_data.float_with_nan = str_data.float_with_nan.astype(float)
metadata_creator = TypeMetadataCreator(MetadataCreatorInput(dataframe=str_data))
metadata = metadata_creator.get_metadata()
metadata_creator = TypeMetadataCreator()
metadata = metadata_creator.get_metadata(MetadataCreatorInput(dataframe=str_data))
self.assertEqual(sum(1 for value in metadata.column_incomplete.values() if value),
1) # any incomplete column

Expand Down Expand Up @@ -173,7 +171,8 @@ def test_embeddings(self):
self.assertIsNot(self.metadata.column_names_clean.values(), {})
self.assertIsNot(self.metadata.column_names, {})

metadata = self.metadata_creator.compute_column_names_embeddings().create_column_embeddings().get_metadata()
metadata = (self.metadata_creator.compute_column_names_embeddings()
.create_column_embeddings().get_metadata(MetadataCreatorInput(dataframe=self.data)))

self.assertIsNot({}, metadata.column_embeddings)
self.assertIsNot({}, metadata.column_name_embeddings)
Expand Down

0 comments on commit 4b08f9d

Please sign in to comment.