-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing container for storing Retrieval Results #544
Conversation
oml/functional/knn.py
Outdated
assert (ids_query.ndim == 1) and (ids_gallery.ndim == 1) and (embeddings.ndim == 2) | ||
assert len(embeddings) <= len(ids_query) + len(ids_gallery) | ||
assert (sequence_ids is None) or (len(sequence_ids) == len(embeddings) and (sequence_ids.ndim == 1)) | ||
assert (labels_gt is None) or (len(labels_gt) <= len(ids_query) + len(ids_gallery) and (labels_gt.ndim == 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or len(labels_gt) == embeddings.shape[0]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree
if labels_gt is not None: | ||
mask_gt_b = labels_gt[ids_query_b][..., None] == labels_gt[ids_gallery][None, ...] | ||
mask_gt_b[mask_to_ignore_b] = False | ||
gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b]) # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If labels_gt
is not None
we can allocate memory for gt_ids
before the for-loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think we cannot because every query has arbitorary number of gt, so we have list of tensors as a result
oml/retrieval/retrieval_results.py
Outdated
|
||
if gt_ids is not None: | ||
assert distances.shape[0] == len(gt_ids) | ||
if not all(len(x) > 0 for x in gt_ids): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be faster to evaluate any(len(x) == 0 for x in gt_ids)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think you are right
oml/retrieval/retrieval_results.py
Outdated
cls, | ||
embeddings: FloatTensor, | ||
dataset: IQueryGalleryDataset, | ||
n_items_to_retrieve: int = 1_000, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest to set default value for n_items_to_retrieve
to 10 or even 5. 1000 is looking too big.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets it be in the middle, i set 100 :)
oml/retrieval/retrieval_results.py
Outdated
|
||
return RetrievalResults(distances=distances, retrieved_ids=retrieved_ids, gt_ids=gt_ids) | ||
|
||
def __repr__(self) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For me it looks like a violation of the agreement about the purpose of __repr__
function.
Called by the repr() built-in function to compute the “official” string representation of an object. If at all possible, this should look like a valid Python expression that could be used to recreate an object with the same value (given an appropriate environment). If this is not possible, a string of the form <...some useful description...> should be returned.
https://docs.python.org/3/reference/datamodel.html#object.__repr__
Maybe it would be better to use __str__
for this message?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
find . -type d -name "__pycache__" -exec rm -r {} + | ||
find . -type f -name "*.log" -exec rm {} + | ||
find . -type f -name "*.predictions.json" -exec rm {} + | ||
rm -rf docs/build |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Poor docs ((
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's annoying when u do search in the full repo and find some trash in docs html files
so yep
oml/functional/knn.py
Outdated
""" | ||
|
||
Args: | ||
embeddings: Matrix with the shape of ``[n, dim]`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usage of n
here and in top_n
confuse. Maybe L
(refer to len
) instead of lowercase n
would be better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree
oml/functional/knn.py
Outdated
|
||
Args: | ||
embeddings: Matrix with the shape of ``[n, dim]`` | ||
ids_query: Tensor with the size of ``Q``, where ``Q <= n``. Each element is withing the range ``(0, n - 1)``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: withing
navigate over repo to find the same errors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, removed everywhere
from oml.utils.misc_torch import pairwise_dist | ||
|
||
|
||
def batched_knn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this func has nothing with KNN. It's exact calculation of distance matrix with truncating. However, I don't know how to call it...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's the same as kNN from sklearn:
from sklearn.neighbors import NearestNeighbors
knn = NearestNeighbors(algorithm="auto", p=2)
knn.fit(features_galleries)
dists, ii_closest = knn.kneighbors(features_queries, n_neighbors=top_k, return_distance=True)
By the way, did not you mean aNN (approximate NN)? But it's not what we are doing here
oml/functional/knn.py
Outdated
gt_ids.extend([LongTensor(row.nonzero()).view(-1) for row in mask_gt_b]) # type: ignore | ||
|
||
distances_b[mask_to_ignore_b] = float("inf") | ||
distances[i : i + bs], retrieved_ids[i : i + bs] = torch.topk(distances_b, k=top_n, largest=False, sorted=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add the second dimension for clarity
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay
oml/retrieval/retrieval_results.py
Outdated
|
||
""" | ||
if not isinstance(dataset, (IVisualizableDataset, IQueryGalleryDataset)): | ||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TypeError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
oml/retrieval/retrieval_results.py
Outdated
ii_gallery = dataset.get_gallery_ids() | ||
|
||
n_galleries_to_show = min(n_galleries_to_show, self.n_retrieved_items) | ||
n_gt_to_show = N_GT_SHOW_EMBEDDING_METRICS if (self.gt_ids is not None) else 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about adding this as an argument with the default N_GT_SHOW_EMBEDDING_METRICS
? 2 is ok for exp logging, but might be not enough for developing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
oml/retrieval/retrieval_results.py
Outdated
n_rows, n_cols = len(query_ids), n_galleries_to_show + 1 + n_gt_to_show | ||
|
||
# iterate over queries | ||
for j, query_idx in enumerate(query_ids): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
j
for rows and i
for cols, like [j,i]
? My entire life it is usually [i,j]
😁
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in mine as well :) swaped
fig.show() | ||
plt.close(fig=fig) | ||
|
||
print(rr) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
arrrr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RRRRR!
|
||
print(rr) | ||
|
||
assert True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, bro
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IT IS WHAT IT IS :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you know I like keeping assert True in the end
it shows that I did not forget to complete the test implementation :)
No description provided.