Skip to content

Commit 2d7c663

Browse files
author
kuangliu
committed
Simplify one_hot_embedding
1 parent b262983 commit 2d7c663

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

utils.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,8 @@ def one_hot_embedding(labels, num_classes):
223223
Returns:
224224
(tensor) encoded labels, sized [N,#classes].
225225
'''
226-
N = labels.size(0)
227-
D = num_classes
228-
y = torch.zeros(N,D)
229-
y[torch.arange(0,N).long(),labels] = 1
230-
return y
226+
y = torch.eye(num_classes) # [D,D]
227+
return y[labels] # [N,D]
231228

232229
def msr_init(net):
233230
'''Initialize layer parameters.'''

0 commit comments

Comments
 (0)