Made inference modality agnostic in re-ranking and other parts of the repo #542
return self.distances_to_return


@pytest.mark.long
I removed this test because it was too tricky and hard to maintain when the interfaces change.
return distances_upd
distances_top = distances_top.view(distances.shape[0], top_n)

distances_upd, ii_rerank = distances_top.sort()
This function is inconsistent for ii_rerank... 😁
What do you mean?
I've added
# todo 522: explain what's going on here
so, when all the interfaces are settled, I will add more explanations there.
I've added examples, I hope it helps
Sorry, I just meant that sort returns indices in arbitrary order for equal values, like in metrics.
Ooh, got it, you are right: it's the same problem.
I hope we will solve it at some point...
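For reference, a minimal plain-Python sketch of the tie problem discussed in this thread and one way to make the order deterministic. This is not the code from this PR: `argsort_with_ties_by_index` is a hypothetical helper and the input values are made up. It breaks ties by the original position, so equal distances always come back in the same order.

```python
def argsort_with_ties_by_index(values):
    """Return indices that sort `values` ascending.

    Equal values are ordered by their original position, so the
    result does not depend on how the sort handles ties.
    """
    # Sorting on (value, index) makes the tie-break rule explicit.
    return sorted(range(len(values)), key=lambda i: (values[i], i))


# Made-up distances with two pairs of equal values.
distances_top = [0.3, 0.1, 0.3, 0.1]

ii_rerank = argsort_with_ties_by_index(distances_top)
distances_upd = [distances_top[i] for i in ii_rerank]

print(ii_rerank)      # [1, 3, 0, 2]
print(distances_upd)  # [0.1, 0.1, 0.3, 0.3]
```

On tensors, `torch.sort` accepts `stable=True`, which preserves the original order of equal elements; whether that fits here (and in metrics) is a separate decision.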
REPRODUCING OUR POSTPROCESSING PAPER: InShop validation with pp:
SOP:
I've used models we trained before:
@@ -1,7 +1,7 @@
postfix: "postprocessing"

seed: 42
- precision: 16
+ precision: 32
I checked that there is no precision 16 for CPU, so that value was confusing.
@@ -72,7 +72,7 @@ def cat_two_sorted_tensors_and_keep_it_sorted(x1: Tensor, x2: Tensor, eps: float
    assert eps >= 0
    assert x1.shape[0] == x2.shape[0]

-    scale = (x2[:, 0] / x1[:, -1]).view(-1, 1)
+    scale = (x2[:, 0] / x1[:, -1]).view(-1, 1).type_as(x1)
An error appeared in half precision, so I added type_as.
@@ -11,7 +11,7 @@
def main_hydra(cfg: DictConfig) -> None:
    cfg = dictconfig_to_dict(cfg)
    download_mock_dataset(MOCK_DATASET_PATH)
-    cfg["data_dir"] = MOCK_DATASET_PATH
+    cfg["data_dir"] = str(MOCK_DATASET_PATH)
The idea here, and in the similar changes below, is to have the same types as in real runs of our pipelines.
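A small sketch of one place where `Path` and `str` values behave differently in a config dict. This is an assumption about why the type matters, not something stated in the PR; the path below is hypothetical and `is_json_serializable` is an illustrative helper.

```python
import json
from pathlib import Path

# Hypothetical location, used only for illustration.
MOCK_DATASET_PATH = Path("/tmp/mock_dataset")

# A config built in a test with a Path object ...
cfg_with_path = {"data_dir": MOCK_DATASET_PATH}
# ... versus the plain string a config file would give us.
cfg_with_str = {"data_dir": str(MOCK_DATASET_PATH)}


def is_json_serializable(cfg):
    """Return True if `cfg` can be dumped with the standard json module."""
    try:
        json.dumps(cfg)
        return True
    except TypeError:
        # json does not know how to encode pathlib.Path objects.
        return False


print(is_json_serializable(cfg_with_path))  # False
print(is_json_serializable(cfg_with_str))   # True
```

Keeping the test config as a string means the test exercises the same code paths as a run that reads `data_dir` from a config file.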
Changelog (all the functions and classes on the right side are modality agnostic):
- EmbeddingPairsDataset, ImagePairsDataset -> PairDataset
- pairwise_inference_on_images, pairwise_inference_on_embeddings -> pairwise_inference
- IDistancesPostprocessor -> (mostly renamed) -> IRetrievalPostprocessor
- PairwisePostprocessor, PairwiseEmbeddingsPostprocessor, PairwiseImagesPostprocessor -> PairwiseReranker
- inference_on_images -> inference
- inference_on_dataframe -> inference_cached

Also:
- EmbeddingMetrics takes optional dataset argument in order to perform postprocessing.

Examples changed:
- train + val and prediction for postprocessor
- retrieval usage
- global_paths parameter to download_mock_dataset so it looks nicer