Made inference modality agnostic in re-ranking and other parts of the repo #542

Merged
merged 27 commits into from
Apr 28, 2024
26 changes: 19 additions & 7 deletions oml/retrieval/postprocessors/pairwise.py
@@ -7,7 +7,11 @@
from oml.interfaces.datasets import IQueryGalleryDataset
from oml.interfaces.models import IPairwiseModel
from oml.interfaces.retrieval import IRetrievalPostprocessor
from oml.utils.misc_torch import cat_two_sorted_tensors_and_keep_it_sorted, take_2d
from oml.utils.misc_torch import (
    assign_2d,
    cat_two_sorted_tensors_and_keep_it_sorted,
    take_2d,
)


class PairwiseReranker(IRetrievalPostprocessor):
@@ -55,16 +59,24 @@ def process(self, distances: Tensor, dataset: IQueryGalleryDataset) -> Tensor:

"""
# todo 522:
# This function is needed only during the migration time. We will directly use `process_neigh` later.
# Thus, the code above is just an adapter for input and output of the `process_neigh` function.
# This function and the code below are only needed during the migration.
# We will use `process_neigh` directly later on.
# So, the code below is just a format adapter:
# 1) it takes the top (dists + ii) of the big distance matrix,
# 2) passes this top to `process_neigh()`,
# 3) puts the processed outputs back in their places in the big distance matrix.

assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids()))

distances, ii_retrieved = distances.sort()
distances, ii_retrieved_upd = self.process_neigh(
retrieved_ids=ii_retrieved, distances=distances, dataset=dataset
# we need this "+10" to activate rescaling if needed (so we have both new and old distances in `process_neigh`).
# anyway, this code is temporary
distances_top, ii_retrieved_top = torch.topk(
distances, k=min(self.top_n + 10, distances.shape[1]), largest=False
)
distances = take_2d(distances, ii_retrieved_upd.argsort())
distances_top_upd, ii_retrieved_upd = self.process_neigh(
retrieved_ids=ii_retrieved_top, distances=distances_top, dataset=dataset
)
distances = assign_2d(x=distances, indices=ii_retrieved_upd, new_values=distances_top_upd)

assert distances.shape == (len(dataset.get_query_ids()), len(dataset.get_gallery_ids()))

2 changes: 1 addition & 1 deletion oml/utils/misc_torch.py
@@ -72,7 +72,7 @@ def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float
assert eps >= 0
assert x1.shape[0] == x2.shape[0]

scale = (x2[:, 0] / x1[:, -1]).view(-1, 1)
scale = (x2[:, 0] / x1[:, -1]).view(-1, 1).type_as(x1)

Review comment (Contributor, Author): The error appeared in half precision, so I added `type_as`.

need_scaling = x1[:, -1] > x2[:, 0]
x1[need_scaling] = x1[need_scaling] * scale[need_scaling] - eps
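
The rescaling in the lines above can be illustrated on a single row. This is a rough sketch under assumptions rather than the library's implementation: it uses plain Python lists instead of tensors, assumes positive distances, and the final concatenation is inferred from the function name, since that part of the function is not visible in this diff. The name `cat_sorted_rows` is hypothetical.

```python
from typing import List


def cat_sorted_rows(x1: List[float], x2: List[float], eps: float = 1e-6) -> List[float]:
    """Concatenate two ascending rows so that the result is still ascending."""
    assert eps >= 0
    if x1[-1] > x2[0]:
        # Shrink x1 so that its largest value lands just below the smallest value of x2.
        scale = x2[0] / x1[-1]
        x1 = [v * scale - eps for v in x1]
    # Assumption: the two rows are simply joined after the optional rescaling.
    return x1 + x2


merged = cat_sorted_rows([2.0, 4.0, 8.0], [4.0, 5.0], eps=0.5)
untouched = cat_sorted_rows([1.0, 2.0], [3.0], eps=0.5)
print(merged, untouched)
```

Rescaling only happens when the last value of the first row exceeds the first value of the second row; otherwise the first row is left as it is.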

@@ -1,7 +1,7 @@
postfix: "postprocessing"

seed: 42
precision: 16
precision: 32

Review comment (Contributor, Author): I checked that there is no precision 16 on CPU, so that value was confusing.

accelerator: cpu
devices: 2
find_unused_parameters: False