Skip to content

Commit f6df6ae

Browse files
bilgeacun authored and facebook-github-bot committed
Optimize MomentumSGDUpdate maximum block size and make it templated
Summary: Removing the maximum number of blocks limit from the operator and making the nesterov parameter templated to remove branching. Reviewed By: BIT-silence Differential Revision: D14567003 fbshipit-source-id: 394c2039ee214adc6ccd2e562e4e9563d307131f
1 parent e3da16a commit f6df6ae

File tree

1 file changed

+58
-24
lines changed

1 file changed

+58
-24
lines changed

caffe2/sgd/momentum_sgd_op_gpu.cu

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
namespace caffe2 {
66

7+
// Number of CUDA blocks needed to cover N elements at
// CAFFE_CUDA_NUM_THREADS threads per block, with no upper cap on the
// grid size. Always returns at least 1, since CUDA does not allow an
// empty grid.
inline int CaffeGetBlocksSGD(const int N) {
  // For N > 0, (N - 1) / threads + 1 == ceil(N / threads), but unlike
  // (N + threads - 1) / threads it cannot overflow int when N is close
  // to INT_MAX. For N <= 0, fall back to the minimum of 1 block.
  return N > 0 ? (N - 1) / CAFFE_CUDA_NUM_THREADS + 1 : 1;
}
13+
// Momentum SGD parameter update.
//
// `nesterov` is a template parameter (rather than a runtime argument) so
// that the nesterov-vs-plain choice is resolved at compile time and the
// kernel body contains no branching on it. Specializations for true and
// false follow below.
//
// Inputs:  g  - gradient,  m - momentum buffer,  lr - learning rate
//          (read from lr[0] on the device).
// Outputs: ng - adjusted gradient, nm - updated momentum buffer.
// `param` is updated in place when non-null.
template <bool nesterov>
__global__ void MomentumSGDKernel(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    float* param);
23+
24+
// Nesterov momentum update.
//
// For each element i:
//   nm[i] = momentum * m[i] + lr * g[i]
//   ng[i] = (1 + momentum) * nm[i] - momentum * m[i]
// written with fmaf as momentum * (mi_new - mi) + mi_new, which is the
// same quantity computed with a fused multiply-add.
// If `param` is non-null, param[i] -= ng[i] is applied in place.
template <>
__global__ void MomentumSGDKernel<true>(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    float* param) {
  const float LR = lr[0];
  CUDA_1D_KERNEL_LOOP(i, N) {
    const float mi = m[i];
    const float mi_new = momentum * mi + LR * g[i];
    nm[i] = mi_new;
    ng[i] = fmaf(momentum, mi_new - mi, mi_new);
    if (param != nullptr) {
      param[i] -= ng[i];
    }
  }
}
45+
46+
// Plain (heavy-ball) momentum update.
//
// For each element i:
//   nm[i] = ng[i] = lr * g[i] + momentum * m[i]
// If `param` is non-null, param[i] -= ng[i] is applied in place.
template <>
__global__ void MomentumSGDKernel<false>(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    float* param) {
  const float LR = lr[0];
  CUDA_1D_KERNEL_LOOP(i, N) {
    const float adjusted_gradient = LR * g[i] + momentum * m[i];
    nm[i] = adjusted_gradient;
    ng[i] = adjusted_gradient;
    if (param != nullptr) {
      param[i] -= adjusted_gradient;
    }
  }
}
@@ -49,12 +76,19 @@ void momentum_sgd_update<CUDAContext>(
4976
const bool nesterov,
5077
float* param,
5178
CUDAContext* context) {
52-
MomentumSGDKernel<<<
53-
CAFFE_GET_BLOCKS(N),
54-
CAFFE_CUDA_NUM_THREADS,
55-
0,
56-
context->cuda_stream()>>>(
57-
N, g, m, ng, nm, lr, momentum, nesterov, param);
79+
if (nesterov) {
80+
MomentumSGDKernel<true>
81+
<<<CaffeGetBlocksSGD(N),
82+
CAFFE_CUDA_NUM_THREADS,
83+
0,
84+
context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
85+
} else {
86+
MomentumSGDKernel<false>
87+
<<<CaffeGetBlocksSGD(N),
88+
CAFFE_CUDA_NUM_THREADS,
89+
0,
90+
context->cuda_stream()>>>(N, g, m, ng, nm, lr, momentum, param);
91+
}
5892
}
5993
6094

0 commit comments

Comments
 (0)