
Commit 016399e

Merge branch 'postproc' of github.com:OML-Team/open-metric-learning into th

AlekseySh committed Jun 4, 2024
2 parents 2c91fe9 + 42adcdf
Showing 13 changed files with 145 additions and 39 deletions.
17 changes: 17 additions & 0 deletions docs/readme/examples_source/retrieval_format.md
@@ -0,0 +1,17 @@
[comment]:dataset-start
```python
from oml.utils import get_mock_texts_dataset, get_mock_images_dataset
from oml.utils.dataframe_format import check_retrieval_dataframe_format

# IMAGES
df_train, df_val = get_mock_images_dataset(global_paths=True)
check_retrieval_dataframe_format(df=df_train)
check_retrieval_dataframe_format(df=df_val)

# TEXTS
df_train, df_val = get_mock_texts_dataset()
check_retrieval_dataframe_format(df=df_train)
check_retrieval_dataframe_format(df=df_val)

```
[comment]:dataset-end
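
For reference, the same helper also validates a user's own table; this mirrors the snippet removed from `docs/source/oml/data.rst` below (paths inside the CSV are resolved against `dataset_root`):

```python
import pandas as pd

from oml.utils.dataframe_format import check_retrieval_dataframe_format

# validate your own dataset table; relative paths in the CSV
# are resolved against dataset_root
check_retrieval_dataframe_format(df=pd.read_csv("/path/to/table.csv"), dataset_root="/path/to/dataset/root/")
```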
6 changes: 6 additions & 0 deletions docs/source/contents/losses.rst
@@ -15,6 +15,8 @@ TripletLoss

.. automethod:: __init__
.. automethod:: forward
.. autoproperty:: last_logs


TripletLossPlain
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -24,6 +26,7 @@ TripletLossPlain

.. automethod:: __init__
.. automethod:: forward
.. autoproperty:: last_logs

TripletLossWithMiner
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -33,6 +36,7 @@ TripletLossWithMiner

.. automethod:: __init__
.. automethod:: forward
.. autoproperty:: last_logs

SurrogatePrecision
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -50,6 +54,7 @@ ArcFaceLoss
:show-inheritance:

.. automethod:: __init__
.. autoproperty:: last_logs

ArcFaceLossWithMLP
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -58,6 +63,7 @@ ArcFaceLossWithMLP
:show-inheritance:

.. automethod:: __init__
.. autoproperty:: last_logs

label_smoothing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 change: 1 addition & 0 deletions docs/source/contents/miners.rst
@@ -60,3 +60,4 @@ MinerWithBank

.. automethod:: __init__
.. automethod:: sample
.. autoproperty:: last_logs
7 changes: 5 additions & 2 deletions docs/source/contents/retrieval.rst
@@ -16,7 +16,10 @@ RetrievalResults
.. automethod:: __init__
.. automethod:: from_embeddings
.. automethod:: visualize
- .. automethod:: n_retrieved_items
+ .. autoproperty:: n_retrieved_items
+ .. autoproperty:: distances
+ .. autoproperty:: retrieved_ids
+ .. autoproperty:: gt_ids

PairwiseReranker
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -26,4 +29,4 @@ PairwiseReranker

.. automethod:: __init__
.. automethod:: process
- .. automethod:: top_n
+ .. autoproperty:: top_n
8 changes: 1 addition & 7 deletions docs/source/oml/data.rst
@@ -29,10 +29,4 @@ Check out the
`examples <https://drive.google.com/drive/folders/12QmUbDrKk7UaYGHreQdz5_nPfXG3klNc?usp=sharing>`_
of dataframes. You can also use a helper to check if your dataset is in the right format:

- .. code-block:: python
-     import pandas as pd
-     from oml.utils.dataframe_format import check_retrieval_dataframe_format
-     check_retrieval_dataframe_format(df=pd.read_csv("/path/to/table.csv"), dataset_root="/path/to/dataset/root/")
+ .. mdinclude:: ../../../docs/readme/examples_source/retrieval_format.md
2 changes: 1 addition & 1 deletion oml/const.py
@@ -86,7 +86,7 @@ def get_cache_folder() -> Path:
Y1_COLUMN = "y_1"
Y2_COLUMN = "y_2"

- OBLIGATORY_COLUMNS = [LABELS_COLUMN, PATHS_COLUMN, SPLIT_COLUMN, IS_QUERY_COLUMN, IS_GALLERY_COLUMN]
+ OBLIGATORY_COLUMNS = [LABELS_COLUMN, SPLIT_COLUMN, IS_QUERY_COLUMN, IS_GALLERY_COLUMN]
BBOXES_COLUMNS = [X1_COLUMN, X2_COLUMN, Y1_COLUMN, Y2_COLUMN]

# Keys for interactions among our classes (datasets, metrics and so on)
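
A quick way to see the effect of dropping `PATHS_COLUMN` is the mock texts dataset added in this commit: its dataframe carries the remaining obligatory columns but no paths, so the format check passes. A minimal sketch using only helpers that appear above:

```python
from oml.utils import get_mock_texts_dataset
from oml.utils.dataframe_format import check_retrieval_dataframe_format

# text data has no file paths, which is why PATHS_COLUMN is no longer obligatory
df_train, df_val = get_mock_texts_dataset()
print(df_train.columns.tolist())  # expected: the obligatory columns above, plus a text column

check_retrieval_dataframe_format(df=df_train)  # passes without any paths column
```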
20 changes: 18 additions & 2 deletions oml/losses/arcface.py
@@ -62,7 +62,7 @@ def __init__(
self.sin_m = np.sin(m)
self.th = -self.cos_m
self.mm = self.sin_m * m
- self.last_logs: Dict[str, float] = {}
+ self._last_logs: Dict[str, float] = {}

def fc(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(F.normalize(x, p=2), F.normalize(self.weight, p=2))
@@ -94,7 +94,15 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

@torch.no_grad()
def _log_accuracy_on_batch(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
self.last_logs["accuracy"] = torch.mean((y == torch.argmax(logits, 1)).to(torch.float32))
self._last_logs["accuracy"] = torch.mean((y == torch.argmax(logits, 1)).to(torch.float32))

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self._last_logs


class ArcFaceLossWithMLP(nn.Module):
@@ -144,5 +152,13 @@ def __init__(
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return self.arcface(self.mlp(x), y)

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self.arcface.last_logs


__all__ = ["ArcFaceLoss", "ArcFaceLossWithMLP"]
46 changes: 34 additions & 12 deletions oml/losses/triplet.py
@@ -1,4 +1,4 @@
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -35,7 +35,7 @@ def __init__(self, margin: Optional[float], reduction: str = "mean", need_logs:
Args:
margin: Margin value, set ``None`` to use `SoftTripletLoss`
reduction: ``mean``, ``sum`` or ``none``
- need_logs: Set ``True`` if you want to store logs
+ need_logs: Set ``True`` to store statistics on the last batch, available via the ``last_logs`` property.
"""
assert reduction in ("mean", "sum", "none")
@@ -46,7 +46,7 @@ def __init__(self, margin: Optional[float], reduction: str = "mean", need_logs:
self.margin = margin
self.reduction = reduction
self.need_logs = need_logs
- self.last_logs: Dict[str, float] = {}
+ self._last_logs: Dict[str, float] = {}

def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
"""
@@ -72,7 +72,7 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
loss = torch.relu(self.margin + positive_dist - negative_dist)

if self.need_logs:
- self.last_logs = {
+ self._last_logs = {
"active_tri": float((loss.clone().detach() > 0).float().mean()),
"pos_dist": float(positive_dist.clone().detach().mean().item()),
"neg_dist": float(negative_dist.clone().detach().mean().item()),
@@ -82,6 +82,14 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:

return loss

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self._last_logs


def get_tri_ids_in_plain(n: int) -> Tuple[List[int], List[int], List[int]]:
"""
@@ -120,15 +128,14 @@ def __init__(self, margin: Optional[float], reduction: str = "mean", need_logs:
Args:
margin: Margin value, set ``None`` to use `SoftTripletLoss`
reduction: ``mean``, ``sum`` or ``none``
- need_logs: Set ``True`` if you want to store logs
+ need_logs: Set ``True`` to store statistics on the last batch, available via the ``last_logs`` property.
"""
assert reduction in ("mean", "sum", "none")
assert (margin is None) or (margin > 0)

super(TripletLossPlain, self).__init__()
self.criterion = TripletLoss(margin=margin, reduction=reduction, need_logs=need_logs)
- self.last_logs = self.criterion.last_logs

def forward(self, features: torch.Tensor) -> Tensor:
"""
@@ -150,10 +157,17 @@ def forward(self, features: torch.Tensor) -> Tensor:
anchor_ii, positive_ii, negative_ii = get_tri_ids_in_plain(n)

loss = self.criterion(features[anchor_ii], features[positive_ii], features[negative_ii])
- self.last_logs = self.criterion.last_logs

return loss

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self.criterion.last_logs


class TripletLossWithMiner(ITripletLossWithMiner):
"""
@@ -176,7 +190,7 @@ def __init__(
margin: Margin value, set ``None`` to use `SoftTripletLoss`
miner: A miner that implements the logic of picking triplets to pass them to the triplet loss.
reduction: ``mean``, ``sum`` or ``none``
- need_logs: Set ``True`` if you want to store logs
+ need_logs: Set ``True`` to store statistics on the last batch, available via the ``last_logs`` property.
"""
assert reduction in ("mean", "sum", "none")
@@ -188,7 +202,7 @@ def __init__(
self.reduction = reduction
self.need_logs = need_logs

- self.last_logs: Dict[str, float] = {}
+ self._last_logs: Dict[str, float] = {}

def forward(self, features: Tensor, labels: Union[Tensor, List[int]]) -> Tensor:
"""
@@ -215,7 +229,7 @@ def avg_d(x1: Tensor, x2: Tensor) -> Tensor:

is_bank_tri = ~is_orig_tri
active = (loss.clone().detach() > 0).float()
- self.last_logs.update(
+ self._last_logs.update(
{
"orig_active_tri": active[is_orig_tri].sum() / is_orig_tri.sum(),
"bank_active_tri": active[is_bank_tri].sum() / is_bank_tri.sum(),
@@ -230,8 +244,8 @@ def avg_d(x1: Tensor, x2: Tensor) -> Tensor:
anchor, positive, negative = self.miner.sample(features=features, labels=labels_list)
loss = self.tri_loss(anchor=anchor, positive=positive, negative=negative)

- self.last_logs.update(self.tri_loss.last_logs)
- self.last_logs.update(getattr(self.miner, "last_logs", {}))
+ self._last_logs.update(self.tri_loss.last_logs)
+ self._last_logs.update(getattr(self.miner, "last_logs", {}))

if self.reduction == "mean":
loss = loss.mean()
@@ -244,5 +258,13 @@ def avg_d(x1: Tensor, x2: Tensor) -> Tensor:

return loss

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self._last_logs


__all__ = ["TLogs", "TripletLoss", "get_tri_ids_in_plain", "TripletLossPlain", "TripletLossWithMiner"]
16 changes: 12 additions & 4 deletions oml/miners/miner_with_bank.py
@@ -1,5 +1,5 @@
from collections import defaultdict
- from typing import Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor, no_grad
@@ -27,7 +27,7 @@ def __init__(
Args:
bank_size_in_batches: Size of the bank.
miner: Miner, for now we only support ``NHardTripletsMiner``
- need_logs: Set ``True`` if you want to track logs.
+ need_logs: Set ``True`` to store statistics on the last batch, available via the ``last_logs`` property.
"""

@@ -46,7 +46,7 @@ def __init__(
self.ptr = 0

self.need_logs = need_logs
- self.last_logs: Dict[str, float] = {}
+ self._last_logs: Dict[str, float] = {}

@no_grad()
def __allocate_if_needed(self, features: Tensor, labels: Tensor) -> None:
@@ -103,7 +103,7 @@ def sample(self, features: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
)

if self.need_logs:
- self.last_logs = self._prepare_logs(
+ self._last_logs = self._prepare_logs(
ids_a=ids_a, ids_p=ids_p, ids_n=ids_n, ignore_anchor_mask=ignore_anchor_mask
)

@@ -138,5 +138,13 @@ def _prepare_logs(

return logs.get_dict_with_results()

@property
def last_logs(self) -> Dict[str, Any]:
"""
Returns:
Dictionary containing useful statistics calculated for the last batch.
"""
return self._last_logs


__all__ = ["MinerWithBank"]
4 changes: 4 additions & 0 deletions oml/retrieval/postprocessors/pairwise.py
@@ -45,6 +45,10 @@ def __init__(

@property
def top_n(self) -> int:
"""
Returns:
Number of gallery items closest to each query that will be processed.
"""
return self._top_n

def process(self, rr: RetrievalResults, dataset: IQueryGalleryDataset) -> RetrievalResults: # type: ignore
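
End to end, the reranker plugs into the retrieval flow as sketched below. Everything outside this commit (`ViTExtractor`, `ConcatSiamese`, `ImageQueryGalleryLabeledDataset`, `inference`, and the `from_embeddings` argument names) is an assumption modeled on the library's usage examples, not part of this diff:

```python
from oml.datasets import ImageQueryGalleryLabeledDataset
from oml.inference import inference
from oml.models import ConcatSiamese, ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults
from oml.retrieval.postprocessors.pairwise import PairwiseReranker
from oml.utils import get_mock_images_dataset

_, df_test = get_mock_images_dataset(global_paths=True)

extractor = ViTExtractor.from_pretrained("vits16_dino")
transforms, _ = get_transforms_for_pretrained("vits16_dino")
dataset = ImageQueryGalleryLabeledDataset(df_test, transform=transforms)

embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)

# rerank the top-3 gallery items per query with a pairwise model
siamese = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100])
reranker = PairwiseReranker(top_n=3, pairwise_model=siamese, num_workers=0, batch_size=4)
print(reranker.top_n)  # 3, via the new read-only property

rr_upd = reranker.process(rr, dataset=dataset)  # signature matches process() above
```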
39 changes: 33 additions & 6 deletions oml/retrieval/retrieval_results.py
@@ -1,5 +1,5 @@
from pprint import pformat
- from typing import List, Optional, Sequence
+ from typing import List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import pandas as pd
@@ -56,15 +56,42 @@ def __init__(
if any(len(x) == 0 for x in gt_ids):
raise RuntimeError("Every query must have at least one relevant gallery id.")

- self.distances = tuple(distances)
- self.retrieved_ids = tuple(retrieved_ids)
- self.gt_ids = tuple(gt_ids) if gt_ids is not None else None
+ self._distances = tuple(distances)
+ self._retrieved_ids = tuple(retrieved_ids)
+ self._gt_ids = tuple(gt_ids) if gt_ids is not None else None

@property
def distances(self) -> Tuple[FloatTensor, ...]:
"""
Returns:
Sorted distances from the queries to their top retrieved gallery items; the tuple length is ``n_query``.
"""
return self._distances

@property
def retrieved_ids(self) -> Tuple[LongTensor, ...]:
"""
Returns:
Indices of the top retrieved gallery items for every query; the tuple length is ``n_query``.
Every index is within the range ``(0, n_gallery - 1)``.
"""
return self._retrieved_ids

@property
def gt_ids(self) -> Optional[Tuple[LongTensor, ...]]:
"""
Returns:
Gallery indices relevant to every query; the tuple length is ``n_query``.
Every element is within the range ``(0, n_gallery - 1)``.
"""
return self._gt_ids

@property
def n_retrieved_items(self) -> int:
"""
- Returns: Number of items retrieved for each query. If queries have different number of retrieved items,
-     returns the maximum of them.
+ Returns:
+     Number of items retrieved for each query. If queries have different numbers of retrieved items,
+     returns the maximum of them.
"""
return max(len(x) for x in self.retrieved_ids)
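
Since `distances`, `retrieved_ids`, and `gt_ids` are now read-only properties backed by private attributes, constructing a `RetrievalResults` by hand and reading them back looks as follows. A minimal sketch with toy tensors; the argument names come from the assignments above:

```python
from torch import tensor

from oml.retrieval.retrieval_results import RetrievalResults

# two queries over a gallery of 4 items; distances are sorted ascending per query
rr = RetrievalResults(
    distances=[tensor([0.1, 0.4]), tensor([0.2])],
    retrieved_ids=[tensor([3, 0]).long(), tensor([2]).long()],
    gt_ids=[tensor([3]).long(), tensor([1]).long()],  # every query needs at least one relevant id
)

print(rr.n_retrieved_items)  # 2: the maximum over queries
print(rr.distances)          # read-only: assigning to it now raises AttributeError
print(rr.retrieved_ids)
print(rr.gt_ids)
```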