4
4
5
5
namespace caffe2 {
6
6
7
+ inline int CaffeGetBlocksSGD (const int N) {
8
+ return std::max (
9
+ (N + CAFFE_CUDA_NUM_THREADS - 1 ) / CAFFE_CUDA_NUM_THREADS,
10
+ // Use at least 1 block, since CUDA does not allow empty block
11
+ 1 );
12
+ }
13
+ template <bool nesterov>
7
14
__global__ void MomentumSGDKernel (
8
15
const int N,
9
16
const float * g,
@@ -12,27 +19,47 @@ __global__ void MomentumSGDKernel(
12
19
float * nm,
13
20
const float * lr,
14
21
const float momentum,
15
- const bool nesterov,
22
+ float * param);
23
+
24
+ template <>
25
+ __global__ void MomentumSGDKernel<true >(
26
+ const int N,
27
+ const float * g,
28
+ const float * m,
29
+ float * ng,
30
+ float * nm,
31
+ const float * lr,
32
+ const float momentum,
16
33
float * param) {
17
34
const float LR = lr[0 ];
18
- if (!nesterov) {
19
- CUDA_1D_KERNEL_LOOP (i, N) {
20
- const float adjusted_gradient = LR * g[i] + momentum * m[i];
21
- nm[i] = adjusted_gradient;
22
- ng[i] = adjusted_gradient;
23
- if (param) {
24
- param[i] -= adjusted_gradient;
25
- }
35
+ CUDA_1D_KERNEL_LOOP (i, N) {
36
+ const float mi = m[i];
37
+ const float mi_new = momentum * mi + LR * g[i];
38
+ nm[i] = mi_new;
39
+ ng[i] = fmaf (momentum, mi_new - mi, mi_new);
40
+ if (param != nullptr ) {
41
+ param[i] -= ng[i];
26
42
}
27
- } else {
28
- CUDA_1D_KERNEL_LOOP (i, N) {
29
- const float mi = m[i];
30
- const float mi_new = momentum * mi + LR * g[i];
31
- nm[i] = mi_new;
32
- ng[i] = (1 + momentum) * mi_new - momentum * mi;
33
- if (param) {
34
- param[i] -= ng[i];
35
- }
43
+ }
44
+ }
45
+
46
+ template <>
47
+ __global__ void MomentumSGDKernel<false >(
48
+ const int N,
49
+ const float * g,
50
+ const float * m,
51
+ float * ng,
52
+ float * nm,
53
+ const float * lr,
54
+ const float momentum,
55
+ float * param) {
56
+ const float LR = lr[0 ];
57
+ CUDA_1D_KERNEL_LOOP (i, N) {
58
+ const float adjusted_gradient = LR * g[i] + momentum * m[i];
59
+ nm[i] = adjusted_gradient;
60
+ ng[i] = adjusted_gradient;
61
+ if (param != nullptr ) {
62
+ param[i] -= adjusted_gradient;
36
63
}
37
64
}
38
65
}
@@ -49,12 +76,19 @@ void momentum_sgd_update<CUDAContext>(
49
76
const bool nesterov,
50
77
float * param,
51
78
CUDAContext* context) {
52
- MomentumSGDKernel<<<
53
- CAFFE_GET_BLOCKS (N),
54
- CAFFE_CUDA_NUM_THREADS,
55
- 0,
56
- context->cuda_stream()>>>(
57
- N, g, m, ng, nm, lr, momentum, nesterov, param);
79
+ if (nesterov) {
80
+ MomentumSGDKernel<true >
81
+ <<<CaffeGetBlocksSGD(N),
82
+ CAFFE_CUDA_NUM_THREADS,
83
+ 0 ,
84
+ context->cuda_stream ()>>>(N, g, m, ng, nm, lr, momentum, param);
85
+ } else {
86
+ MomentumSGDKernel<false >
87
+ <<<CaffeGetBlocksSGD(N),
88
+ CAFFE_CUDA_NUM_THREADS,
89
+ 0 ,
90
+ context->cuda_stream ()>>>(N, g, m, ng, nm, lr, momentum, param);
91
+ }
58
92
}
59
93
60
94
0 commit comments