-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathtest_train_with_sequence.py
88 lines (61 loc) · 2.82 KB
/
test_train_with_sequence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pandas as pd
import pytest
import torch
from torch import nn
from tqdm import tqdm
from oml.const import LABELS_COLUMN, MOCK_DATASET_PATH, SEQUENCE_COLUMN
from oml.datasets.images import ImageQueryGalleryLabeledDataset
from oml.metrics.embeddings import EmbeddingMetrics, TMetricsDict_ByLabels
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.utils.misc import compare_dicts_recursively, set_global_seed
def validation(df: pd.DataFrame) -> TMetricsDict_ByLabels:
    """Run a validation pass over ``df`` and return the computed metrics.

    The global seed is fixed to 42, a query/gallery dataset is built from
    ``df`` (rooted at ``MOCK_DATASET_PATH``) and iterated in batches of 4.
    Embeddings are simply the flattened ``input_tensors`` of each batch, so
    no trained model is involved. Every batch is fed to an
    ``EmbeddingMetrics`` calculator configured with the dataset's sequence
    key and CMC@1.

    Args:
        df: Dataframe describing the validation samples.

    Returns:
        The result of ``EmbeddingMetrics.compute_metrics()``.
    """
    set_global_seed(42)

    # Parameter-free "model": embeddings are just the flattened inputs.
    flatten = nn.Flatten()

    dataset = ImageQueryGalleryLabeledDataset(df, dataset_root=MOCK_DATASET_PATH)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)

    metrics_calc = EmbeddingMetrics(
        extra_keys=("paths",),
        sequence_key=dataset.sequence_key,
        cmc_top_k=(1,),
    )
    metrics_calc.setup(num_samples=len(dataset))

    with torch.no_grad():
        for batch in tqdm(loader):
            batch["embeddings"] = flatten(batch["input_tensors"])
            metrics_calc.update_data(batch)

    return metrics_calc.compute_metrics()
def test_invariants_in_validation_with_sequences_1() -> None:
    # Giving every validation sample its own unique sequence id must not
    # change the metrics: with all-distinct ids the sequence-based ignoring
    # logic has nothing to act on.
    _, df = download_mock_dataset(MOCK_DATASET_PATH)

    unique_ids = list(range(len(df)))
    df_with_seq = df.copy()
    df_with_seq[SEQUENCE_COLUMN] = unique_ids

    baseline_metrics = validation(df)
    sequence_metrics = validation(df_with_seq)

    assert compare_dicts_recursively(sequence_metrics, baseline_metrics)
def test_invariants_in_validation_with_sequences_2() -> None:
    # Setup A: every sample is both a query and a gallery item, each one
    # carrying its own unique sequence id.
    # Setup B: the same samples are duplicated into a query-only copy and a
    # gallery-only copy, so every query has its twin in the gallery under the
    # same sequence id.
    # The metrics are expected to be identical for A and B.
    _, df = download_mock_dataset(MOCK_DATASET_PATH)

    df_a = df.copy()
    df_a[SEQUENCE_COLUMN] = list(range(len(df_a)))
    df_a["is_query"] = True
    df_a["is_gallery"] = True

    # `assign` returns a new frame, so df_a itself stays untouched.
    only_queries = df_a.assign(is_query=True, is_gallery=False)
    only_galleries = df_a.assign(is_query=False, is_gallery=True)
    df_b = pd.concat([only_queries, only_galleries], ignore_index=True)

    metrics_a = validation(df_a)
    metrics_b = validation(df_b)

    assert compare_dicts_recursively(metrics_a, metrics_b)
def test_invariants_in_validation_with_sequences_3() -> None:
    # If labels == sequence, then no sample has a pool of right answers in the
    # gallery to pick from. Thus, we expect our validation to raise an error.
    _, df = download_mock_dataset(MOCK_DATASET_PATH)

    df_with_seq = df.copy()
    df_with_seq[SEQUENCE_COLUMN] = df_with_seq[LABELS_COLUMN]

    # Sanity check, deliberately kept OUTSIDE of `pytest.raises`: the baseline
    # dataframe must validate without errors. Were this call inside the
    # `raises` block, a RuntimeError from the baseline would make the test pass
    # without ever exercising the sequence == labels case.
    validation(df)

    # Only the call that is expected to fail lives inside the `raises` block.
    with pytest.raises(RuntimeError):
        validation(df_with_seq)