Recall at K #3

Open · wants to merge 4 commits into master
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
dataset/carDB/
cars_annos.mat
logs/
car_ims.tgz
23 changes: 23 additions & 0 deletions README.md
@@ -149,7 +149,30 @@ $ python main.py --gpu=0 \
--scale=12.0 --check_epoch=5 \
--ps_alpha=0.40 --ps_mu=1.0
```
#### Proxy-Anchor loss with CARS196

```bash
$ python main.py --gpu=0 \
--save_path=./logs/CARS196_Proxy_Anchor \
--data=./dataset/carDB --data_name=cars196 \
--dim=512 --batch_size=128 --epochs=130 \
--freeze_BN --loss=Proxy_Anchor \
--decay_step=50 --decay_stop=50 --n_instance=1 \
--scale=23.0 --check_epoch=5
```

#### PS + Proxy-Anchor loss with CARS196

```bash
$ python main.py --gpu=0 \
--save_path=./logs/CARS196_PS_Proxy_Anchor \
--data=./dataset/carDB --data_name=cars196 \
--dim=512 --batch_size=128 --epochs=130 \
--freeze_BN --loss=Proxy_Anchor \
--decay_step=50 --decay_stop=50 --n_instance=1 \
--scale=23.0 --check_epoch=5 \
--ps_alpha=0.40 --ps_mu=1.0
```
### Check Test Results
```
$ tensorboard --logdir=logs --port=10000
98 changes: 98 additions & 0 deletions cosCentroids.csv

Large diffs are not rendered by default.

98 changes: 98 additions & 0 deletions cosProxy.csv

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions dataset/prepare_cars.py
@@ -6,7 +6,7 @@
import os
import shutil

base_path = 'dataset'
base_path = './cars'
trainPrefix = os.path.join(base_path, 'carDB/train/')
testPrefix = os.path.join(base_path, 'carDB/test/')
for lines in open(os.path.join(base_path, 'cars_annos.txt')):
@@ -25,7 +25,7 @@
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)

try:
os.rmdir(os.path.join(base_path, 'car_ims'))
except Exception as e:
print (e)
#Download the tar of all images combined from https://ai.stanford.edu/~jkrause/cars/car_dataset.html
#Download the annotations txt file from the same link
#Place all of the above inside the cars folder
#Run this file from outside the cars folder
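
A minimal setup sketch matching the comments above, assuming the image tarball keeps the name `car_ims.tgz` (the same name ignored in `.gitignore`) and the annotation file is saved as `cars_annos.txt`, the filename the script reads; both must be downloaded manually first:

```python
# Hypothetical setup helper; paths follow the comments above and prepare_cars.py.
import os
import tarfile

base = "./cars"  # base_path used by dataset/prepare_cars.py

# car_ims.tgz and cars_annos.txt must already be downloaded into ./cars
with tarfile.open(os.path.join(base, "car_ims.tgz")) as tar:
    tar.extractall(base)  # unpacks to ./cars/car_ims/

# Expected layout before running the script:
#   cars/car_ims/000001.jpg ...
#   cars/cars_annos.txt
# Then, from the directory containing cars/:
#   python dataset/prepare_cars.py
```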
28 changes: 28 additions & 0 deletions dataset/prepare_cub200.py
@@ -0,0 +1,28 @@
import os
import shutil

base_path = './cub200'
trainPrefix = os.path.join(base_path, 'cub200DB/train/')
testPrefix = os.path.join(base_path, 'cub200DB/test/')
for lines in open(os.path.join(base_path, 'lists/files.txt')):
line = lines.strip().split('.')
classInd = int(line[0])
fname = lines.split('/')[1].split('\n')[0] #the name of the file we want
print(fname)
file_path = os.path.join(base_path + '/images', lines.split('\n')[0])
print(file_path)
if classInd <= 100:
ddr = trainPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)
else:
ddr = testPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)


#Download images and lists from http://www.vision.caltech.edu/visipedia/CUB-200.html
#Place them into the cub200 folder
#Run prepare_cub200.py from outside the cub200 folder
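
As a quick check of the split implemented above (classes 1-100 go to train, 101-200 to test), something like the following could be run afterwards; the paths come from the script, but the check itself is only an illustration, not part of this PR:

```python
# Sanity check (illustrative only): prepare_cub200.py moves classes 1-100 into
# train/ and classes 101-200 into test/ under ./cub200/cub200DB.
import os

base = "./cub200/cub200DB"
for split in ("train", "test"):
    classes = os.listdir(os.path.join(base, split))
    print(split, len(classes), "class folders")  # expected: 100 each
```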
86 changes: 86 additions & 0 deletions dataset/prepare_grocery.py
@@ -0,0 +1,86 @@
'''
proxy-synthesis
Copyright (c) 2021-present NAVER Corp.
Apache License v2.0
'''
import os
import shutil

base_path = './grocery'
trainPrefix = os.path.join(base_path, 'groceryDB/train/')
testPrefix = os.path.join(base_path, 'groceryDB/test/')
for lines in open(os.path.join(base_path, 'train.txt')):
lines = lines.strip().split(',')
classInd = int(lines[1])
classInd = classInd + 1 #1 indexing it
print(lines[0])
try:
fname = lines[0].split('/')[4]
except IndexError:
fname = lines[0].split('/')[3]
file_path = os.path.join(base_path, lines[0])
if classInd <= 41:
ddr = trainPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'train' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)
else:
ddr = testPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'train' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)

for lines in open(os.path.join(base_path, 'test.txt')):
lines = lines.strip().split(',')
classInd = int(lines[1])
classInd = classInd + 1 #1 indexing it
print(lines[0])
try:
fname = lines[0].split('/')[4]
except IndexError:
fname = lines[0].split('/')[3]
file_path = os.path.join(base_path, lines[0])
if classInd <= 41:
ddr = trainPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'test' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)
else:
ddr = testPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'test' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)


for lines in open(os.path.join(base_path, 'val.txt')):
lines = lines.strip().split(',')
classInd = int(lines[1])
classInd = classInd + 1 #1 indexing it
print(lines[0])
try:
fname = lines[0].split('/')[4]
except IndexError:
fname = lines[0].split('/')[3]
file_path = os.path.join(base_path, lines[0])
if classInd <= 41:
ddr = trainPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'val' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)
else:
ddr = testPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
newfname = fname.split('.')[0] + 'val' + '.jpg'
shutil.move(file_path, ddr + '/' + newfname)

#The dataset is divided into 81 fine-grained classes (which we need) and 42 coarse-grained classes (which we don't)
#Labels are zero-indexed, so they are made one-indexed here
#Train, test, and val were separated in the dataset with similar file naming, so they are combined here and the files renamed
#Data can be downloaded from https://github.com/marcusklasson/GroceryStoreDataset
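
Given the 81 one-indexed fine-grained labels described above, the `classInd <= 41` branch sends classes 1-41 to train and 42-81 to test; a quick illustrative check (not part of this PR):

```python
# Illustrative check after running prepare_grocery.py from the folder containing ./grocery:
# the merged train/test/val images end up split by class id, 1-41 vs 42-81.
import os

base = "./grocery/groceryDB"
for split in ("train", "test"):
    folders = sorted(os.listdir(os.path.join(base, split)), key=int)
    print(split, len(folders), "classes, range", folders[0], "-", folders[-1])
```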

38 changes: 38 additions & 0 deletions dataset/prepare_vegfru.py
@@ -0,0 +1,38 @@
import os
import shutil

base_path = './data/veg_fru'
trainPrefix = os.path.join(base_path, 'vegfruDB/train/')
testPrefix = os.path.join(base_path, 'vegfruDB/test/')
valPrefix = os.path.join(base_path, 'vegfruDB/val/')

for lines in open(os.path.join(base_path, 'vegfru_list/vegfru_train.txt')):
lines = lines.strip().split(' ')
classInd = int(lines[1])
fname = lines[0].split('/')[2]
file_path = os.path.join(base_path, lines[0])
ddr = trainPrefix + str(classInd)
if not os.path.exists(ddr):
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)


for lines in open(os.path.join(base_path, 'vegfru_list/vegfru_test.txt')):
lines = lines.strip().split(' ')
classInd = int(lines[1])
fname = lines[0].split('/')[2]
file_path = os.path.join(base_path, lines[0])
ddr = testPrefix + lines[1]
if not os.path.exists(ddr):
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)

for lines in open(os.path.join(base_path, 'vegfru_list/vegfru_val.txt')):
lines = lines.strip().split(' ')
classInd = int(lines[1])
fname = lines[0].split('/')[2]
file_path = os.path.join(base_path, lines[0])
ddr = valPrefix + lines[1]
if not os.path.exists(ddr):
os.makedirs(ddr)
shutil.move(file_path, ddr + '/' + fname)
1 change: 1 addition & 0 deletions loss/__init__.py
@@ -5,3 +5,4 @@
'''
from .proxy_losses import Norm_SoftMax
from .proxy_losses import Proxy_NCA
from .proxy_losses import Proxy_Anchor
134 changes: 134 additions & 0 deletions loss/proxy_losses.py
@@ -10,6 +10,11 @@
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

def proxy_synthesis(input_l2, proxy_l2, target, ps_alpha, ps_mu):
'''
@@ -102,3 +107,132 @@ def forward(self, input, target):
loss = torch.mean(torch.sum(-pos_target * F.log_softmax(-dist_mat, -1), -1))

return loss


def binarize(T, nb_classes):
T = T.cpu().numpy()
import sklearn.preprocessing
T = sklearn.preprocessing.label_binarize(
T, classes = range(0, nb_classes)
)
T = torch.FloatTensor(T).cuda()
return T

def l2_norm(input):
input_size = input.size()
buffer = torch.pow(input, 2)
normp = torch.sum(buffer, 1).add_(1e-12)
norm = torch.sqrt(normp)
_output = torch.div(input, norm.view(-1, 1).expand_as(input))
output = _output.view(input_size)
return output

class Proxy_Anchor(torch.nn.Module):
def __init__(self, sz_embed, nb_classes, mrg = 0.1, alpha = 32, ps_mu=0.0, ps_alpha=0.0):
torch.nn.Module.__init__(self)
# Proxy Anchor Initialization
self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).cuda())
nn.init.kaiming_normal_(self.proxies, mode='fan_out')

self.nb_classes = nb_classes
self.sz_embed = sz_embed
self.mrg = mrg
self.alpha = alpha
self.ps_mu = ps_mu
self.ps_alpha = ps_alpha

def forward(self, X, T):
P = self.proxies
input_l2 = l2_norm(X)
proxy_l2 = l2_norm(P)
if self.ps_mu > 0.0:
input_l2, proxy_l2, target = proxy_synthesis(input_l2, proxy_l2, T,
self.ps_alpha, self.ps_mu)

cos = F.linear(input_l2, proxy_l2) # Calculate cosine similarity
P_one_hot = binarize(T = T, nb_classes = self.nb_classes)
N_one_hot = 1 - P_one_hot

pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
neg_exp = torch.exp(self.alpha * (cos + self.mrg))

with_pos_proxies = torch.nonzero(P_one_hot.sum(dim = 0) != 0).squeeze(dim = 1) # The set of positive proxies of data in the batch
num_valid_proxies = len(with_pos_proxies) # The number of positive proxies

P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
loss = pos_term + neg_term
return loss


class Proxy_Anchor_Compare(torch.nn.Module):
def __init__(self, sz_embed, nb_classes, mrg = 0.1, alpha = 32, ps_mu=0.0, ps_alpha=0.0):
torch.nn.Module.__init__(self)
# Proxy Anchor Initialization
self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).cuda())
nn.init.kaiming_normal_(self.proxies, mode='fan_out')

self.nb_classes = nb_classes
self.sz_embed = sz_embed
self.mrg = mrg
self.alpha = alpha
self.ps_mu = ps_mu
self.ps_alpha = ps_alpha

def forward(self, X, T):
P = self.proxies
input_l2 = l2_norm(X)
proxy_l2 = l2_norm(P)
if self.ps_mu > 0.0:
input_l2, proxy_l2, target = proxy_synthesis(input_l2, proxy_l2, T,
self.ps_alpha, self.ps_mu)

cos = F.linear(input_l2, proxy_l2) # Calculate cosine similarity
P_one_hot = binarize(T = T, nb_classes = self.nb_classes)
N_one_hot = 1 - P_one_hot

pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
neg_exp = torch.exp(self.alpha * (cos + self.mrg))

with_pos_proxies = torch.nonzero(P_one_hot.sum(dim = 0) != 0).squeeze(dim = 1) # The set of positive proxies of data in the batch
num_valid_proxies = len(with_pos_proxies) # The number of positive proxies

P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)

pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
loss = pos_term + neg_term

### Per-class batch centroids

centroid = []
# proxy = []

for i in range(self.nb_classes):

inputsFromSameClass = T == i
val = torch.mean(input_l2[inputsFromSameClass], dim=0) # mean embedding of this class within the batch
if torch.all(torch.isnan(val)): # classes absent from the batch give a NaN mean
val = torch.zeros(self.sz_embed, device=input_l2.device)
centroid.append(val) # rows that map to the same class
# proxy.append(torch.mean(proxy_l2[inputsFromSameClass])) # proxies of a class are identical, so the mean makes no difference

centroid = torch.stack(centroid)
# proxy = torch.stack(proxy)

### Centroid and proxy similarity matrices
cosCentroids = F.linear(centroid, centroid)
cosProxy = F.linear(proxy_l2, proxy_l2)

### Sort each row by similarity (ascending, so the most similar classes come last)
cosCentroids = torch.argsort(cosCentroids, dim=1)
cosProxy = torch.argsort(cosProxy, dim=1)
### Dump the orderings to CSV for offline comparison
np.savetxt("cosCentroids.csv", cosCentroids.cpu().detach().numpy(), delimiter=",")
np.savetxt("cosProxy.csv", cosProxy.cpu().detach().numpy(), delimiter=",")

return loss
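
The two CSVs written above store, per class, the class indices ordered by similarity (torch.argsort is ascending, so the most similar classes sit in the last columns). Below is a hedged sketch of one way to compare them offline, e.g. measuring how often the top-k neighbours of a class centroid agree with the top-k neighbours of its proxy; this is an illustration of a possible analysis, not code contained in the PR:

```python
# Illustrative offline comparison of cosCentroids.csv and cosProxy.csv
# (both are argsorted similarity matrices; rows = classes, ascending order).
import numpy as np

cent = np.loadtxt("cosCentroids.csv", delimiter=",").astype(int)
prox = np.loadtxt("cosProxy.csv", delimiter=",").astype(int)

k = 5
overlaps = []
for c_row, p_row in zip(cent, prox):
    # Drop the last entry (usually the class itself) and take the k entries
    # just before it as the k nearest neighbours under each ordering.
    c_top = set(c_row[-(k + 1):-1])
    p_top = set(p_row[-(k + 1):-1])
    overlaps.append(len(c_top & p_top) / k)
print("mean top-%d neighbour overlap: %.3f" % (k, float(np.mean(overlaps))))
```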
5 changes: 4 additions & 1 deletion main.py
@@ -182,7 +182,7 @@ def main():
test_loader = torch.utils.data.DataLoader(
test_image,
batch_size=128, shuffle=False,
num_workers=args.workers, pin_memory=True)
num_workers=args.workers, pin_memory=False)

if args.data_name.lower() == 'inshop':
image_info = np.array(test_image.imgs)
@@ -216,6 +216,8 @@ def main():
elif args.loss.lower() == 'Proxy_NCA'.lower():
criterion = loss.Proxy_NCA(args.dim, args.C, scale=args.scale,
ps_mu=args.ps_mu, ps_alpha=args.ps_alpha).cuda()
elif args.loss.lower() == 'Proxy_Anchor'.lower():
criterion = loss.Proxy_Anchor(args.dim, args.C).cuda()
else:
raise ValueError("{} is not supported loss name".format(args.loss))
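
A minimal usage sketch of the new branch, assuming a 512-d embedding and the 98 CARS196 training classes used elsewhere in this PR (dummy tensors only; a CUDA device is required because the loss allocates its proxies on the GPU):

```python
# Illustration only: mirrors criterion = loss.Proxy_Anchor(args.dim, args.C).cuda()
import torch
import loss

criterion = loss.Proxy_Anchor(512, 98).cuda()          # sz_embed=512, nb_classes=98
embeddings = torch.randn(128, 512, device="cuda")      # one batch of raw embeddings
labels = torch.randint(0, 98, (128,), device="cuda")   # class indices in [0, 98)
out = criterion(embeddings, labels)                    # scalar Proxy-Anchor loss
out.backward()                                         # gradients reach the proxy parameters
```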

@@ -302,6 +304,7 @@ def main():

for epoch in range(args.start_epoch, args.epochs):
epoch += 1
torch.cuda.empty_cache()
print('Training in Epoch[{}]'.format(epoch))
adjust_learning_rate(optimizer, epoch, args)
