Skip to content

Commit c8b99de

Browse files
update tests
1 parent 4a61724 commit c8b99de

File tree

7 files changed

+295
-276
lines changed

7 files changed

+295
-276
lines changed

.github/workflows/py_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ jobs:
8686
coverageFile: coverage.xml
8787
token: ${{ secrets.GITHUB_TOKEN }}
8888
thresholdAll: 0.7
89-
thresholdNew: 0.8
89+
thresholdNew: 0.7
9090

9191
- uses: actions/upload-artifact@v4
9292
if: github.event_name == 'pull_request'

similarity_framework/src/impl/comparator/comparator_by_column.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from statistics import mean
88

99
from logging_ import logger
10+
from similarity_framework.src.impl.comparator.distance_functions import HausdorffDistanceMin, AverageDist
1011
from similarity_framework.src.impl.comparator.utils import cosine_sim, are_columns_null
1112
from similarity_framework.src.interfaces.comparator.comparator import HandlerType, Comparator
1213
from similarity_framework.src.models.metadata import Metadata, KindMetadata, CategoricalMetadata
@@ -272,8 +273,8 @@ def compare_constants(
272273
value: float = 0 if metadata1.value == metadata2.value else 1
273274
else:
274275
value = 1 - cosine_sim(
275-
metadata1.value_embeddings,
276-
metadata2.value_embeddings,
276+
metadata1.value_embeddings[0], # todo 0 nebo 1
277+
metadata2.value_embeddings[0],
277278
)
278279
# if nulls are equal and exist
279280
if nulls == 0 and metadata1.nulls:
@@ -446,6 +447,9 @@ def from_settings(settings: AnalysisSettings) -> "ComparatorByColumn":
446447
comparator.add_comparator_type(ColumnKindHandler(weight=settings.weights.kinds))
447448
if settings.type_basic or settings.type_structural or settings.type_advanced:
448449
comparator.add_comparator_type(ColumnTypeHandler(settings.weights.type))
450+
if settings.distance_function:
451+
func = HausdorffDistanceMin() if settings.distance_function == "HausdorffDistanceMin" else AverageDist()
452+
comparator.set_distance_function(func)
449453
return comparator
450454

451455
def __init__(self):

similarity_framework/src/impl/comparator/comparator_by_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def compare_constants(self, metadata1: Metadata, metadata2: Metadata) -> pd.Data
353353
value_re.loc[column1, column2] = int(meta1.value != meta2.value)
354354
else:
355355
value_re.loc[column1, column2] = 1 - cosine_sim(
356-
meta1.value_embeddings,
357-
meta2.value_embeddings,
356+
meta1.value_embeddings[0], #todo 0 nebo 1
357+
meta2.value_embeddings[0],
358358
)
359359

360360
# 0 distance if values are the same otherwise 1

similarity_framework/src/impl/comparator/distance_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,19 @@ def compute(self, distance_table: pd.DataFrame) -> float:
1818
row_mins = distance_table.min(axis=1)
1919
column_mins = distance_table.min(axis=0)
2020
return min(row_mins.max(), column_mins.max())
21+
22+
class AverageDist(DistanceFunction):
23+
"""Hausdorff distance class"""
24+
25+
def compute(self, distance_table: pd.DataFrame) -> float:
26+
"""
27+
Compute Hausdorff distance with min function.
28+
:param distance_table: dataframe
29+
:return: float between 0 and 1
30+
"""
31+
if distance_table.size == 0:
32+
return np.nan
33+
row_avg = distance_table.min(axis=1)
34+
# column_avg = distance_table.min(axis=0)
35+
# return min(row_avg.mean(), column_avg.mean())
36+
return row_avg.mean()

similarity_framework/src/models/metadata.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
self.distribution = distribution
7070
self.nulls = null_values
7171
self.value = value
72-
self.value_embeddings = None if value[0] is not str else model.encode(list(value))
72+
self.value_embeddings = None if type(value[0]) is not str else model.encode(list(value))
7373

7474
def __str__(self):
7575
return f"BoolMetadata(values={self.value},distribution={self.distribution}, null_values={self.nulls})"
@@ -91,8 +91,8 @@ def __init__(
9191
self.nulls = null_values
9292
self.longest = longest
9393
self.shortest = shortest
94-
self.longest_embeddings = None if longest is not str else model.encode(longest)
95-
self.shortest_embeddings = None if shortest is not str else model.encode(shortest)
94+
self.longest_embeddings = None if type(longest) is not str else model.encode(longest)
95+
self.shortest_embeddings = None if type(shortest) is not str else model.encode(shortest)
9696
self.ratio_max_length = ratio_max_length
9797

9898
def __str__(self):
@@ -114,7 +114,8 @@ def __init__(
114114
self.nulls = null_values
115115
self.value = value
116116
self.distribution = distribution
117-
self.value_embeddings = None if value[0] is not str else model.encode(list(value))
117+
#model.encode(list(value)).view(-1, 1)
118+
self.value_embeddings = None if type(value[0]) is not str else model.encode(list(value))
118119

119120
def __str__(self):
120121
return f"ConstantMetadata(values={self.value}, null_values={self.nulls}, distribution={self.distribution})"

similarity_framework/src/models/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from pydantic import Field, BaseModel, AliasChoices
22
from pydantic_settings import BaseSettings, SettingsConfigDict
33

4+
from similarity_framework.src.impl.comparator.distance_functions import HausdorffDistanceMin
5+
from similarity_framework.src.interfaces.common import DistanceFunction
6+
47

58
class WeightSettings(BaseModel):
69
column_embeddings: int = Field(1, description="Weight for column embeddings")
@@ -22,6 +25,7 @@ class AnalysisSettings(BaseSettings):
2225
type_structural: bool = Field(default=False, description="Use structural type comparison")
2326
type_basic: bool = Field(default=False, description="Use basic type comparison")
2427
kinds: bool = Field(default=False, description="Use kinds for comparison")
28+
distance_function: str = Field(default="HausdorffDistanceMin", description="Distance function for comparison")
2529

2630
## only for comparator
2731
size: bool = Field(default=False, description="Use size for comparison")

0 commit comments

Comments
 (0)