import time

import torch
import cppcuda_tutorial


def trilinear_interpolation_py(feats, points):
    """
    Trilinearly interpolate the 8 corner features of a cube at local points.

    Any number of leading batch dimensions is accepted, as long as `feats`
    and `points` share them (the original (N, 8, F) / (N, 3) layout is the
    single-batch-axis case).

    Inputs:
        feats: (..., 8, F) features at the cube corners. Corner index is
               4*i + 2*j + k, where i, j, k in {0, 1} select the low/high
               end along the first, second and third coordinate.
        points: (..., 3) local coordinates in [-1, 1]

    Outputs:
        feats_interp: (..., F)
    """
    # Map each coordinate from [-1, 1] to a [0, 1] blend factor. The 0:1
    # style slices keep a trailing size-1 axis so the factors broadcast
    # against the F feature channels.
    u = (points[..., 0:1] + 1) / 2
    v = (points[..., 1:2] + 1) / 2
    w = (points[..., 2:3] + 1) / 2

    # Bilinear weights over the (v, w) face; they are shared by the low-u
    # and high-u faces of the cube.
    a = (1 - v) * (1 - w)
    b = (1 - v) * w
    c = v * (1 - w)
    # Algebraically v * w, written as the complement of the other three.
    d = 1 - a - b - c

    # Blend the two faces along u.
    feats_interp = (1 - u) * (a * feats[..., 0, :] +
                              b * feats[..., 1, :] +
                              c * feats[..., 2, :] +
                              d * feats[..., 3, :]) + \
                   u * (a * feats[..., 4, :] +
                        b * feats[..., 5, :] +
                        c * feats[..., 6, :] +
                        d * feats[..., 7, :])

    return feats_interp


if __name__ == '__main__':
    # Benchmark: time the CUDA extension against the pure-PyTorch reference
    # on the same random inputs, then check that the two outputs agree.
    N = 65536
    F = 256
    feats = torch.rand(N, 8, F, device='cuda').requires_grad_()
    points = torch.rand(N, 3, device='cuda') * 2 - 1  # uniform in [-1, 1)

    # CUDA kernels are launched asynchronously: wait for the input-generation
    # work above to finish so it is not billed to the first measurement.
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t = time.perf_counter()
    out_cuda = cppcuda_tutorial.trilinear_interpolation(feats, points)
    torch.cuda.synchronize()  # wait for the kernel before reading the clock
    print(' cuda time', time.perf_counter() - t, 's')

    # The synchronize above also guarantees a drained queue for this run.
    t = time.perf_counter()
    out_py = trilinear_interpolation_py(feats, points)
    torch.cuda.synchronize()
    print('pytorch time', time.perf_counter() - t, 's')

    # True when both implementations match within allclose's default tolerances.
    print(torch.allclose(out_py, out_cuda))