From 891d6184b611a042575dab5c0601767ea692c0dd Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 17 Jun 2024 23:31:25 +0600 Subject: [PATCH 1/2] fixed bottleneck in postprocessing --- oml/datasets/images.py | 7 +++++-- oml/datasets/texts.py | 7 +++++-- oml/retrieval/postprocessors/pairwise.py | 6 ++++-- oml/utils/misc_torch.py | 2 +- tests/test_integrations/test_lightning/test_pipeline.py | 7 +++++-- tests/test_integrations/utils.py | 8 ++++---- 6 files changed, 24 insertions(+), 13 deletions(-) diff --git a/oml/datasets/images.py b/oml/datasets/images.py index d92cc84ac..aa6332ff9 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -267,11 +267,14 @@ def __init__( labels_key=labels_key, ) + self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + def get_query_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + return self._gallery_ids class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset): diff --git a/oml/datasets/texts.py b/oml/datasets/texts.py index d6aef89c7..73c98e53f 100644 --- a/oml/datasets/texts.py +++ b/oml/datasets/texts.py @@ -189,11 +189,14 @@ def __init__( index_key=index_key, ) + self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + def get_query_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + return self._gallery_ids class TextQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset): diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index ccd81b68f..9f23541bf 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -117,9 +117,11 @@ def _process_raw( # Queries have different number of retrieved items, so we track what pairs are relevant to what queries (bounds) pairs = [] bounds = [0] + query_ids = dataset.get_query_ids() + gallery_ids = dataset.get_gallery_ids() for iq, ids_gallery in enumerate(retrieved_ids): - ids_gallery_global = dataset.get_gallery_ids()[ids_gallery][: self.top_n].tolist() - ids_query_global = [dataset.get_query_ids()[iq].item()] * len(ids_gallery_global) + ids_gallery_global = gallery_ids[ids_gallery][: self.top_n].tolist() + ids_query_global = [query_ids[iq].item()] * len(ids_gallery_global) pairs.extend(list(zip(ids_query_global, ids_gallery_global))) bounds.append(bounds[-1] + len(ids_gallery_global)) diff --git a/oml/utils/misc_torch.py b/oml/utils/misc_torch.py index 75009d3c3..defd75d6a 100644 --- a/oml/utils/misc_torch.py +++ b/oml/utils/misc_torch.py @@ -7,7 +7,7 @@ import torch from torch import Tensor, cdist -TSingleValues = Union[int, float, np.float_, np.int_, torch.Tensor] +TSingleValues = Union[int, float, np.float64, np.int_, torch.Tensor] TSequenceValues = Union[List[float], Tuple[float, ...], np.ndarray, torch.Tensor] TOnlineValues = Union[TSingleValues, TSequenceValues] diff --git a/tests/test_integrations/test_lightning/test_pipeline.py b/tests/test_integrations/test_lightning/test_pipeline.py index 99fef832c..c4cb64f5a 100644 --- a/tests/test_integrations/test_lightning/test_pipeline.py +++ b/tests/test_integrations/test_lightning/test_pipeline.py @@ -24,6 +24,9 @@ def __init__(self, labels: List[int], im_size: int): self.im_size = im_size self.extra_data = dict() + self._query_ids = torch.arange(len(self)).long() + self._gallery_ids = torch.arange(len(self)).long() + def __getitem__(self, item: int) -> Dict[str, Any]: input_tensors = torch.rand((3, self.im_size, self.im_size)) label = torch.tensor(self.labels[item]).long() @@ -40,10 +43,10 @@ def get_labels(self) -> np.ndarray: return np.array(self.labels) def get_query_ids(self) -> LongTensor: - return torch.arange(len(self)).long() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return torch.arange(len(self)).long() + return self._gallery_ids class DummyCommonModule(pl.LightningModule): diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index bf86c3277..f642e498b 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -48,8 +48,8 @@ def __init__( assert len(embeddings) == len(is_query) == len(is_gallery) self._embeddings = embeddings - self._is_query = is_query - self._is_gallery = is_gallery + self._query_ids = is_query.nonzero().squeeze() + self._gallery_ids = is_gallery.nonzero().squeeze() self.extra_data = {} if categories is not None: @@ -81,10 +81,10 @@ def __len__(self) -> int: return len(self._embeddings) def get_query_ids(self) -> LongTensor: - return self._is_query.nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return self._is_gallery.nonzero().squeeze() + return self._gallery_ids class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset): From 140d9dbec620073c61be97ee7267474c277c49f6 Mon Sep 17 00:00:00 2001 From: alekseysh Date: Mon, 17 Jun 2024 23:47:39 +0600 Subject: [PATCH 2/2] made pip instruction shorter --- README.md | 6 +----- docs/readme/installation.md | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 755011cfa..17305fa20 100644 --- a/README.md +++ b/README.md @@ -335,11 +335,7 @@ for our paper ## [Installation](https://open-metric-learning.readthedocs.io/en/latest/oml/installation.html) ```shell -pip install -U open-metric-learning -``` - -If you need OML for NLP, install the extra requirements with: -```shell +pip install -U open-metric-learning; # minimum dependencies pip install -U open-metric-learning[nlp] ``` diff --git a/docs/readme/installation.md b/docs/readme/installation.md index 098b56ade..92b54120a 100644 --- a/docs/readme/installation.md +++ b/docs/readme/installation.md @@ -1,9 +1,5 @@ ```shell -pip install -U open-metric-learning -``` - -If you need OML for NLP, install the extra requirements with: -```shell +pip install -U open-metric-learning; # minimum dependencies pip install -U open-metric-learning[nlp] ```