
Commit de11fe0

weiyangfb authored and facebook-github-bot committed
migrate PReLU to ATen (pytorch#11758)
Summary:
- fixes pytorch#10723
- migrate PReLU to ATen and deprecate legacy PReLU
- performance:

CPU with weight.numel() = 1
```
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
100 loops, best of 100: 9.43 ms per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 24.4 ms per loop

>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 695 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.47 ms per loop
```

CPU with weight.numel() = channels
```
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 603 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 13.3 ms per loop

>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 655 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.45 ms per loop
```

CUDA with weight.numel() = 1
```
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 187 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.01 ms per loop

>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 195 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.28 ms per loop
```

CUDA with weight.numel() = channel
```
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 174 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.27 ms per loop

>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 181 µs per loop

>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.26 ms per loop
```

The huge CPU performance regression when weight.numel() = 1 is addressed by replacing at::CPU_tensor_apply* with parallelized kernels.

ezyang SsnL zou3519 soumith

Pull Request resolved: pytorch#11758

Differential Revision: D9995799

Pulled By: weiyangfb

fbshipit-source-id: d289937c78075f46a54dafbde92fab0cc4b5b86e
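For reference, the behavior benchmarked above is the elementwise PReLU definition that the migrated kernels implement: out = x where x > 0, and out = weight * x otherwise, with the per-channel weight broadcast over dim 1. A minimal sketch (not part of this commit; the tensor sizes mirror the benchmark setup) that checks nn.PReLU against that definition:

```python
# Sketch: nn.PReLU vs. the elementwise definition the migrated kernels implement.
import torch
import torch.nn as nn

x = torch.randn(100, 100, 100)

# shared weight: weight.numel() == 1
m = nn.PReLU()
assert torch.allclose(m(x), torch.where(x > 0, x, m.weight * x))

# per-channel weights: weight.numel() == channels (dim 1 of the input)
mc = nn.PReLU(100)
ref = torch.where(x > 0, x, mc.weight.view(1, 100, 1) * x)
assert torch.allclose(mc(x), ref)
```

These are the same two cases (one shared weight vs. one weight per channel) that the new CPU kernels below split into.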
1 parent 89d56ae commit de11fe0

File tree

20 files changed: +550, −556 lines


aten/src/ATen/core/Tensor.h (+2)

```diff
@@ -551,6 +551,8 @@ struct AT_API Tensor {
   Tensor & round_();
   Tensor relu() const;
   Tensor & relu_();
+  Tensor prelu(const Tensor & weight) const;
+  std::tuple<Tensor,Tensor> prelu_backward(const Tensor & grad_output, const Tensor & weight) const;
   Tensor hardshrink(Scalar lambd=0.5) const;
   Tensor hardshrink_backward(const Tensor & grad_out, Scalar lambd) const;
   Tensor rsqrt() const;
```

aten/src/ATen/core/TensorMethods.h (+6)

```diff
@@ -935,6 +935,12 @@ inline Tensor Tensor::relu() const {
 inline Tensor & Tensor::relu_() {
   return type().relu_(*this);
 }
+inline Tensor Tensor::prelu(const Tensor & weight) const {
+  return type().prelu(*this, weight);
+}
+inline std::tuple<Tensor,Tensor> Tensor::prelu_backward(const Tensor & grad_output, const Tensor & weight) const {
+  return type().prelu_backward(grad_output, *this, weight);
+}
 inline Tensor Tensor::hardshrink(Scalar lambd) const {
   return type().hardshrink(*this, lambd);
 }
```

aten/src/ATen/core/Type.h (+2)

```diff
@@ -505,6 +505,8 @@ struct AT_API Type {
   virtual Tensor & round_(Tensor & self) const = 0;
   virtual Tensor relu(const Tensor & self) const = 0;
   virtual Tensor & relu_(Tensor & self) const = 0;
+  virtual Tensor prelu(const Tensor & self, const Tensor & weight) const = 0;
+  virtual std::tuple<Tensor,Tensor> prelu_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight) const = 0;
   virtual Tensor hardshrink(const Tensor & self, Scalar lambd) const = 0;
   virtual Tensor hardshrink_backward(const Tensor & grad_out, const Tensor & self, Scalar lambd) const = 0;
   virtual Tensor rsqrt(const Tensor & self) const = 0;
```
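The three headers above expose the new native op on Tensor and Type. From Python it is reachable directly as torch.prelu, and nn.functional.prelu / nn.PReLU dispatch to the same entry point. A small usage sketch (assuming a build that includes this change):

```python
# Sketch: calling the native op directly and via the functional wrapper.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)
w = torch.full((3,), 0.25)   # one weight per channel; channel dim is dim 1

y = torch.prelu(x, w)        # the native ATen op
assert torch.allclose(y, F.prelu(x, w))
```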

aten/src/ATen/native/Activation.cpp (+245)

```diff
@@ -4,6 +4,7 @@
 #include "ATen/NativeFunctions.h"
 #include "ATen/core/Half.h"
 
+
 namespace at { namespace native {
 
 static const double SELU_ALPHA = 1.6732632423543772848170429916717;
@@ -43,6 +44,250 @@ Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Genera
   return at::rrelu_with_noise_(self, self.type().tensor(), lower, upper, training, generator);
 }
 
+// -----------------------------------
+// prelu forward
+// -----------------------------------
+template <typename scalar_t>
+void inline prelu_cpu_kernel_share_weights(
+  Tensor& result,
+  const Tensor& input,
+  const Tensor& weight) {
+
+  int64_t i;
+  int64_t input_numel = input.numel();
+  auto result_data = result.data<scalar_t>();
+  auto input_data = input.data<scalar_t>();
+  auto weight_val = weight.data<scalar_t>()[0];
+
+  #pragma omp parallel for private(i) if (input_numel > 1000)
+  for (i = 0; i < input_numel; i++) {
+    scalar_t input_data_val = input_data[i];
+    // to allow for compiler optimization, here splitting into two lines:
+    scalar_t r = (input_data_val > 0) ? scalar_t(1) : weight_val;
+    result_data[i] = r * input_data_val;
+  }
+}
+
+template <typename scalar_t>
+void inline prelu_cpu_kernel_multi_weights(
+  Tensor& result,
+  const Tensor& input,
+  const Tensor& weight,
+  int64_t input_dim0_size,
+  int64_t channel_size,
+  int64_t input_stride0,
+  int64_t input_stride1) {
+
+  int64_t i, j, k;
+  int64_t input_numel = input.numel();
+  scalar_t* result_data = result.data<scalar_t>();
+  scalar_t* input_data = input.data<scalar_t>();
+  scalar_t* weight_data = weight.data<scalar_t>();
+
+  #pragma omp parallel for private(i,j,k) if (input.numel() > 1000)
+  for (i = 0; i < input_dim0_size; ++i) {
+    int64_t offset = i * channel_size * input_stride1;
+    scalar_t* n_input_data = input_data + offset;
+    scalar_t* n_result_data = result_data + offset;
+    for (j = 0; j < channel_size; ++j) {
+      for (k = 0; k < input_stride1; ++k) {
+        // to allow for compiler optimization, here splitting into two lines:
+        scalar_t w = (n_input_data[k] > 0) ? scalar_t(1) : weight_data[j];
+        n_result_data[k] = w * n_input_data[k];
+      }
+      n_input_data += input_stride1;
+      n_result_data += input_stride1;
+    }
+  }
+}
+
+Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
+  auto input = self.contiguous();
+  auto weight = weight_.contiguous();
+
+  AT_CHECK(input.is_contiguous());
+  AT_CHECK(weight.is_contiguous());
+
+  int64_t weight_num = weight.numel();
+  Tensor result = at::empty_like(input);
+  auto strides = input.strides();
+
+  // case1: shared weight for all channels
+  if (weight_num == 1) {
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+      prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
+    });
+  }
+  else { // case2: multiple weights, one for each channel
+    int64_t input_ndim = input.dim();
+    AT_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
+
+    int64_t channel_size = 1; // channel_size default to 1
+    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
+
+    if (input_ndim > 1) {
+      channel_size = input.size(1); // channel is the 2nd dim of input
+      input_dim0_size = input.size(0);
+      input_stride0 = strides[0];
+      input_stride1 = strides[1];
+    }
+    AT_CHECK(channel_size == weight_num,
+      "Mismatch of parameter numbers and input channel size. Found parameter numbers = %d, and channel size = %d.",
+      weight_num, channel_size);
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+      prelu_cpu_kernel_multi_weights<scalar_t>(
+        result,
+        input,
+        weight,
+        input_dim0_size,
+        channel_size,
+        input_stride0,
+        input_stride1);
+    });
+  }
+  return result;
+}
+
+// -----------------------------------
+// prelu backward
+// -----------------------------------
+template <typename scalar_t>
+void inline prelu_cpu_backward_kernel_share_weights(
+  const Tensor& input,
+  const Tensor& weight,
+  const Tensor& grad_out,
+  Tensor& input_grad,
+  Tensor& weight_grad) {
+
+  int64_t i;
+  int64_t input_numel = input.numel();
+  scalar_t sum = 0;
+  auto input_data = input.data<scalar_t>();
+  auto weight_val = weight.data<scalar_t>()[0];
+  auto grad_out_data = grad_out.data<scalar_t>();
+  auto input_grad_data = input_grad.data<scalar_t>();
+  auto weight_grad_data = weight_grad.data<scalar_t>();
+
+  #pragma omp parallel for private(i) reduction(+:sum) if (input_numel > 1000)
+  for (i = 0; i < input_numel; i++) {
+    scalar_t input_data_val = input_data[i];
+    scalar_t grad_out_data_val = grad_out_data[i];
+    // to allow for compiler optimization, here splitting into two lines:
+    scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_val;
+    input_grad_data[i] = w * grad_out_data_val;
+    // to allow for compiler optimization, here splitting into two lines:
+    scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
+    sum += mask * input_data_val * grad_out_data_val;
+  }
+  weight_grad_data[0] = sum;
+}
+
+template <typename scalar_t>
+void inline prelu_cpu_backward_kernel_multi_weights(
+  const Tensor& input,
+  const Tensor& weight,
+  const Tensor& grad_out,
+  Tensor& input_grad,
+  Tensor& weight_grad_collector,
+  int64_t input_dim0_size,
+  int64_t channel_size,
+  int64_t input_stride0,
+  int64_t input_stride1) {
+
+  int64_t i, j, k;
+  int64_t input_numel = input.numel();
+  auto input_data = input.data<scalar_t>();
+  auto weight_data = weight.data<scalar_t>();
+  auto grad_out_data = grad_out.data<scalar_t>();
+  auto input_grad_data = input_grad.data<scalar_t>();
+  auto weight_grad_collector_data = weight_grad_collector.data<scalar_t>();
+
+  #pragma omp parallel for private(i, j, k) if (input.numel() > 1000)
+  for (i = 0; i < input_dim0_size; i++) {
+    for (j = 0; j < channel_size; j++) {
+      for (k = 0; k < input_stride1; k++) {
+        int64_t pos = i * input_stride0 + j * input_stride1 + k;
+        scalar_t weight_data_val = weight_data[j];
+        scalar_t input_data_val = input_data[pos];
+        scalar_t grad_out_data_val = grad_out_data[pos];
+        // to allow for compiler optimization, here splitting into two lines:
+        scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_data_val;
+        input_grad_data[pos] = w * grad_out_data_val;
+        // to allow for compiler optimization, here splitting into two lines:
+        scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
+        weight_grad_collector_data[pos] = mask * input_data_val * grad_out_data_val;
+      }
+    }
+  }
+}
+
+std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
+  auto input = self.contiguous();
+  auto grad_out = grad_out_.contiguous();
+  auto weight = weight_.contiguous();
+
+  AT_CHECK(input.is_contiguous());
+  AT_CHECK(grad_out.is_contiguous());
+  AT_CHECK(weight.is_contiguous());
+
+  int64_t weight_num = weight.numel();
+  auto strides = input.strides();
+  auto dims = input.dim();
+
+  Tensor input_grad = at::empty_like(input);
+  Tensor weight_grad = at::empty_like(weight);
+  Tensor weight_grad_collector = at::empty_like(input);
+
+  // case1: shared parameter for all channels
+  if (weight_num == 1) {
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+      prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
+    });
+  }
+  else { // case2: multiple parameters, one for each channel
+    int64_t input_ndim = input.dim();
+    AT_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
+
+    int64_t channel_size = 1; // channel_size default to 1
+    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
+
+    if (input_ndim > 1) {
+      channel_size = input.size(1); // channel is the 2nd dim of input
+      input_dim0_size = input.size(0);
+      input_stride0 = strides[0];
+      input_stride1 = strides[1];
+    }
+    AT_CHECK(channel_size == weight_num,
+      "Mismatch of parameter numbers and input channel size. Found parameter numbers = %d, and channel size = %d.",
+      weight_num, channel_size);
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+      prelu_cpu_backward_kernel_multi_weights<scalar_t>(
+        input,
+        weight,
+        grad_out,
+        input_grad,
+        weight_grad_collector,
+        input_dim0_size,
+        channel_size,
+        input_stride0,
+        input_stride1);
+    });
+    // update weight_grad
+    std::vector<int64_t> reduce_dims;
+    reduce_dims.push_back(0);
+    if (dims > 2) {
+      for(int64_t i = 2; i < dims; i++) reduce_dims.push_back(i);
+    }
+    weight_grad = weight_grad_collector.sum(reduce_dims);
+  }
+  return std::tuple<Tensor, Tensor>{input_grad, weight_grad};
+}
+
+// -----------------------------------
+// hardshrink
+// -----------------------------------
 Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
   AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {
```
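For reference, prelu_backward_cpu above produces: input_grad equal to grad_out where the input is positive and weight * grad_out elsewhere, and weight_grad equal to grad_out * input accumulated over non-positive inputs and reduced over every dimension except the channel dimension (the weight_grad_collector.sum(reduce_dims) step). A hedged sketch checking these formulas against autograd (double precision and a 3-d input with channel dim 1 are assumptions of the check, not requirements of the op):

```python
# Sketch: the PReLU gradients the new CPU kernels compute, checked via autograd.
import torch

x = torch.randn(4, 3, 5, dtype=torch.double, requires_grad=True)
w = torch.rand(3, dtype=torch.double, requires_grad=True)

y = torch.prelu(x, w)
grad_out = torch.randn_like(y)
y.backward(grad_out)

with torch.no_grad():
    wb = w.view(1, 3, 1)                                  # broadcast over channel dim 1
    ref_dx = torch.where(x > 0, grad_out, wb * grad_out)  # factor is 1 if x > 0 else weight
    ref_dw = (grad_out * x * (x <= 0).to(x.dtype)).sum(dim=(0, 2))

assert torch.allclose(x.grad, ref_dx)
assert torch.allclose(w.grad, ref_dw)
```

The sum over dims (0, 2) mirrors reduce_dims in the diff, which always contains dim 0 plus every dim greater than 1.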
