Skip to content

Commit ba41d70

Browse files
update tests
1 parent c8b99de commit ba41d70

File tree

5 files changed

+80
-5
lines changed

5 files changed

+80
-5
lines changed

similarity_framework/src/impl/comparator/comparator_by_column.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def compare_constants(
273273
value: float = 0 if metadata1.value == metadata2.value else 1
274274
else:
275275
value = 1 - cosine_sim(
276-
metadata1.value_embeddings[0], # todo 0 nebo 1
276+
metadata1.value_embeddings[0], # todo 0 nebo 1
277277
metadata2.value_embeddings[0],
278278
)
279279
# if nulls are equal and exist

similarity_framework/src/impl/comparator/comparator_by_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def compare_constants(self, metadata1: Metadata, metadata2: Metadata) -> pd.Data
353353
value_re.loc[column1, column2] = int(meta1.value != meta2.value)
354354
else:
355355
value_re.loc[column1, column2] = 1 - cosine_sim(
356-
meta1.value_embeddings[0], #todo 0 nebo 1
356+
meta1.value_embeddings[0], # todo 0 nebo 1
357357
meta2.value_embeddings[0],
358358
)
359359

similarity_framework/src/impl/comparator/distance_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def compute(self, distance_table: pd.DataFrame) -> float:
1919
column_mins = distance_table.min(axis=0)
2020
return min(row_mins.max(), column_mins.max())
2121

22+
2223
class AverageDist(DistanceFunction):
2324
"""Average distance class (mean of the row-wise minimum distances)"""
2425

@@ -33,4 +34,4 @@ def compute(self, distance_table: pd.DataFrame) -> float:
3334
row_avg = distance_table.min(axis=1)
3435
# column_avg = distance_table.min(axis=0)
3536
# return min(row_avg.mean(), column_avg.mean())
36-
return row_avg.mean()
37+
return row_avg.mean()

similarity_framework/src/models/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(
114114
self.nulls = null_values
115115
self.value = value
116116
self.distribution = distribution
117-
#model.encode(list(value)).view(-1, 1)
117+
# model.encode(list(value)).view(-1, 1)
118118
self.value_embeddings = None if type(value[0]) is not str else model.encode(list(value))
119119

120120
def __str__(self):

tests/similarity_framework/test_similarity_comparator.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
ColumnExactNamesHandler as ColumnExactNamesHandlerByColumn,
1212
ColumnKindHandler, ColumnEmbeddingsHandler
1313
)
14-
from similarity_framework.src.impl.comparator.utils import concat
14+
from similarity_framework.src.impl.comparator.distance_functions import AverageDist
15+
from similarity_framework.src.impl.comparator.utils import concat, cosine_sim, fill_result, are_columns_null, create_string_from_columns
1516
from similarity_framework.src.models.metadata import MetadataCreatorInput
1617
from similarity_framework.src.models.similarity import Settings
1718
from similarity_framework.src.impl.metadata.type_metadata_creator import TypeMetadataCreator
@@ -40,12 +41,85 @@ def test_hausdorff_min(self):
4041
self.assertEqual(HausdorffDistanceMin().compute(df3), 1)
4142
self.assertEqual(HausdorffDistanceMin().compute(df4), 3)
4243

44+
def test_average_dist(self):
    # Each case pairs a distance table with the mean of its per-row minima,
    # which is the value AverageDist.compute is expected to return.
    cases = [
        (pd.DataFrame([(2, 3, 3), (1, 4, 2), (5, 1, 2)]), 4 / 3),
        (pd.DataFrame([(7, 2, 2), (8, 3, 4), (9, 2, 5)]), 7 / 3),
        (pd.DataFrame([(1, 1, 3), (1, 2, 3), (1, -1, 2)]), 1 / 3),
        (pd.DataFrame([(5, 3, 4), (2, 8, 8), (1, 100, 100)]), 6 / 3),
    ]
    for distance_table, expected in cases:
        # The expected values are non-terminating binary fractions, so an
        # exact == comparison would depend on how the mean is accumulated.
        self.assertAlmostEqual(AverageDist().compute(distance_table), expected)
53+
4354
def test_get_ratio(self):
    # All pairs share the same 5:3 proportion, supplied in both argument
    # orders; each one must round to 1.67.
    for first, second in ((3, 5), (5, 3), (15, 9), (9, 15)):
        ratio = get_ratio(first, second)
        self.assertEqual(round(ratio, 2), 1.67)
4859

60+
def test_cosine_sim(self):
    # (first vector, second vector, expected cosine similarity).
    # The expected values are quoted to three decimal places only.
    cases = [
        ([1, 2, 3], [1, 2, 3], 1),
        ([1, 2, 3], [3, 2, 1], 0.714),
        ([1, 2, 3], [1, 2, 4], 0.991),
        ([1, 2, 3], [1, 2, 2], 0.98),
        ([1, 2, 3], [-1, -2, -3], -1),
    ]
    for first, second, expected in cases:
        # Compare with a tolerance matching the quoted precision: exact
        # equality only holds if cosine_sim itself rounds its result, and
        # even identical vectors may yield 0.9999999... in floating point.
        self.assertAlmostEqual(cosine_sim(first, second), expected, delta=1e-3)
66+
67+
def test_fill_result(self):
    # Index -> column-name mappings; they agree at indices 0 and 1 and
    # differ at index 2 ('c' vs 'd').
    metadata1_names = {0: 'a', 1: 'b', 2: 'c'}
    metadata2_names = {0: 'a', 1: 'b', 2: 'd'}
    # Expected frame: 0.0 where both mappings hold the same name, 1.0
    # everywhere else.
    # NOTE(review): presumably this mirrors how fill_result scores name
    # equality -- confirm against its implementation.
    data = {
        0: [0.0, 1.0, 1.0],
        1: [1.0, 0.0, 1.0],
        2: [1.0, 1.0, 1.0],
    }

    res = pd.DataFrame(data)
    # The leftover debug print of the expected frame was removed; the
    # assertion alone decides the outcome.
    self.assertTrue(fill_result(metadata1_names, metadata2_names).equals(res))
79+
80+
def test_create_string_from_columns(self):
    # Input: two small tables with two integer columns each, plus the
    # table names in matching order.
    table_names = ['table1', 'table2']
    database = [
        pd.DataFrame({'col1': [1, 2, 3], 'col2': [4, 5, 6]}),
        pd.DataFrame({'col1': [7, 8, 9], 'col2': [10, 11, 12]}),
    ]

    sentences, sentences_datasets = create_string_from_columns(database, table_names)

    # Expected: one comma-separated sentence per column, in table order ...
    self.assertEqual(
        sentences,
        ['1, 2, 3', '4, 5, 6', '7, 8, 9', '10, 11, 12'],
    )
    # ... and, in parallel, the name of the table each sentence came from.
    self.assertEqual(
        sentences_datasets,
        ['table1', 'table1', 'table2', 'table2'],
    )
109+
110+
class TestAreColumnsNull(unittest.TestCase):
    """Checks the tuple returned by are_columns_null for empty and non-empty inputs."""

    def test_both_columns_empty(self):
        # Both inputs are empty sets.
        outcome = are_columns_null(set(), set(), "Test message")
        self.assertEqual(outcome, (True, 0))

    def test_first_column_empty(self):
        # Only the first input is empty.
        outcome = are_columns_null(set(), {1, 2, 3}, "Test message")
        self.assertEqual(outcome, (True, 1))

    def test_second_column_empty(self):
        # Only the second input is empty.
        outcome = are_columns_null({1, 2, 3}, set(), "Test message")
        self.assertEqual(outcome, (True, 1))

    def test_both_columns_non_empty(self):
        # Neither input is empty.
        outcome = are_columns_null({1, 2, 3}, {4, 5, 6}, "Test message")
        self.assertEqual(outcome, (False, 0))
122+
49123

50124
class TestSingleSpecificComparator(unittest.TestCase):
51125
def setUp(self):

0 commit comments

Comments
 (0)