diff --git a/README.md b/README.md index 755011cfa..17305fa20 100644 --- a/README.md +++ b/README.md @@ -335,11 +335,7 @@ for our paper ## [Installation](https://open-metric-learning.readthedocs.io/en/latest/oml/installation.html) ```shell -pip install -U open-metric-learning -``` - -If you need OML for NLP, install the extra requirements with: -```shell +pip install -U open-metric-learning; # minimum dependencies pip install -U open-metric-learning[nlp] ``` diff --git a/docs/readme/installation.md b/docs/readme/installation.md index 098b56ade..92b54120a 100644 --- a/docs/readme/installation.md +++ b/docs/readme/installation.md @@ -1,9 +1,5 @@ ```shell -pip install -U open-metric-learning -``` - -If you need OML for NLP, install the extra requirements with: -```shell +pip install -U open-metric-learning; # minimum dependencies pip install -U open-metric-learning[nlp] ``` diff --git a/oml/datasets/images.py b/oml/datasets/images.py index d92cc84ac..aa6332ff9 100644 --- a/oml/datasets/images.py +++ b/oml/datasets/images.py @@ -267,11 +267,14 @@ def __init__( labels_key=labels_key, ) + self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + def get_query_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + return self._gallery_ids class ImageQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset): diff --git a/oml/datasets/texts.py b/oml/datasets/texts.py index d6aef89c7..73c98e53f 100644 --- a/oml/datasets/texts.py +++ b/oml/datasets/texts.py @@ -189,11 +189,14 @@ def __init__( index_key=index_key, ) + self._query_ids = BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + self._gallery_ids = BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + def get_query_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_QUERY_COLUMN]).nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return BoolTensor(self.df[IS_GALLERY_COLUMN]).nonzero().squeeze() + return self._gallery_ids class TextQueryGalleryDataset(IVisualizableDataset, IQueryGalleryDataset): diff --git a/oml/retrieval/postprocessors/pairwise.py b/oml/retrieval/postprocessors/pairwise.py index ccd81b68f..9f23541bf 100644 --- a/oml/retrieval/postprocessors/pairwise.py +++ b/oml/retrieval/postprocessors/pairwise.py @@ -117,9 +117,11 @@ def _process_raw( # Queries have different number of retrieved items, so we track what pairs are relevant to what queries (bounds) pairs = [] bounds = [0] + query_ids = dataset.get_query_ids() + gallery_ids = dataset.get_gallery_ids() for iq, ids_gallery in enumerate(retrieved_ids): - ids_gallery_global = dataset.get_gallery_ids()[ids_gallery][: self.top_n].tolist() - ids_query_global = [dataset.get_query_ids()[iq].item()] * len(ids_gallery_global) + ids_gallery_global = gallery_ids[ids_gallery][: self.top_n].tolist() + ids_query_global = [query_ids[iq].item()] * len(ids_gallery_global) pairs.extend(list(zip(ids_query_global, ids_gallery_global))) bounds.append(bounds[-1] + len(ids_gallery_global)) diff --git a/tests/test_integrations/test_lightning/test_pipeline.py b/tests/test_integrations/test_lightning/test_pipeline.py index 99fef832c..c4cb64f5a 100644 --- a/tests/test_integrations/test_lightning/test_pipeline.py +++ b/tests/test_integrations/test_lightning/test_pipeline.py @@ -24,6 +24,9 @@ def __init__(self, labels: List[int], im_size: int): self.im_size = im_size self.extra_data = dict() + self._query_ids = torch.arange(len(self)).long() + self._gallery_ids = torch.arange(len(self)).long() + def __getitem__(self, item: int) -> Dict[str, Any]: input_tensors = torch.rand((3, self.im_size, self.im_size)) label = torch.tensor(self.labels[item]).long() @@ -40,10 +43,10 @@ def get_labels(self) -> np.ndarray: return np.array(self.labels) def get_query_ids(self) -> LongTensor: - return torch.arange(len(self)).long() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return torch.arange(len(self)).long() + return self._gallery_ids class DummyCommonModule(pl.LightningModule): diff --git a/tests/test_integrations/utils.py b/tests/test_integrations/utils.py index bf86c3277..f642e498b 100644 --- a/tests/test_integrations/utils.py +++ b/tests/test_integrations/utils.py @@ -48,8 +48,8 @@ def __init__( assert len(embeddings) == len(is_query) == len(is_gallery) self._embeddings = embeddings - self._is_query = is_query - self._is_gallery = is_gallery + self._query_ids = is_query.nonzero().squeeze() + self._gallery_ids = is_gallery.nonzero().squeeze() self.extra_data = {} if categories is not None: @@ -81,10 +81,10 @@ def __len__(self) -> int: return len(self._embeddings) def get_query_ids(self) -> LongTensor: - return self._is_query.nonzero().squeeze() + return self._query_ids def get_gallery_ids(self) -> LongTensor: - return self._is_gallery.nonzero().squeeze() + return self._gallery_ids class EmbeddingsQueryGalleryLabeledDataset(EmbeddingsQueryGalleryDataset, IQueryGalleryLabeledDataset):