 #include "ATen/NativeFunctions.h"
 #include "ATen/core/Half.h"
 
+
 namespace at { namespace native {
 
 static const double SELU_ALPHA = 1.6732632423543772848170429916717;
 
@@ -43,6 +44,250 @@ Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
   return at::rrelu_with_noise_(self, self.type().tensor(), lower, upper, training, generator);
 }
 
+// -----------------------------------
+// prelu forward
+// -----------------------------------
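+// PReLU(x) = x when x > 0, and weight * x otherwise. The weight is either a
+// single shared scalar or one value per channel (dim 1 of the input).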
+template <typename scalar_t>
+void inline prelu_cpu_kernel_share_weights(
+  Tensor& result,
+  const Tensor& input,
+  const Tensor& weight) {
+
+  int64_t i;
+  int64_t input_numel = input.numel();
+  auto result_data = result.data<scalar_t>();
+  auto input_data = input.data<scalar_t>();
+  auto weight_val = weight.data<scalar_t>()[0];
+
+  #pragma omp parallel for private(i) if (input_numel > 1000)
+  for (i = 0; i < input_numel; i++) {
+    scalar_t input_data_val = input_data[i];
+    // split into two lines so the compiler can optimize the branch:
+    scalar_t r = (input_data_val > 0) ? scalar_t(1) : weight_val;
+    result_data[i] = r * input_data_val;
+  }
+}
+
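+// One weight per channel: the channel is dim 1 of the input, so the loops walk
+// batch (dim 0) x channel x the remaining inner elements.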
+template <typename scalar_t>
+void inline prelu_cpu_kernel_multi_weights(
+  Tensor& result,
+  const Tensor& input,
+  const Tensor& weight,
+  int64_t input_dim0_size,
+  int64_t channel_size,
+  int64_t input_stride0,
+  int64_t input_stride1) {
+
+  int64_t i, j, k;
+  int64_t input_numel = input.numel();
+  scalar_t* result_data = result.data<scalar_t>();
+  scalar_t* input_data = input.data<scalar_t>();
+  scalar_t* weight_data = weight.data<scalar_t>();
+
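+  // The input is contiguous, so element (i, j, k) sits at offset
+  // i * channel_size * input_stride1 + j * input_stride1 + k.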
+  #pragma omp parallel for private(i,j,k) if (input_numel > 1000)
+  for (i = 0; i < input_dim0_size; ++i) {
+    int64_t offset = i * channel_size * input_stride1;
+    scalar_t* n_input_data = input_data + offset;
+    scalar_t* n_result_data = result_data + offset;
+    for (j = 0; j < channel_size; ++j) {
+      for (k = 0; k < input_stride1; ++k) {
+        // split into two lines so the compiler can optimize the branch:
+        scalar_t w = (n_input_data[k] > 0) ? scalar_t(1) : weight_data[j];
+        n_result_data[k] = w * n_input_data[k];
+      }
+      n_input_data += input_stride1;
+      n_result_data += input_stride1;
+    }
+  }
+}
+
+Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
+  auto input = self.contiguous();
+  auto weight = weight_.contiguous();
+
+  AT_CHECK(input.is_contiguous());
+  AT_CHECK(weight.is_contiguous());
+
+  int64_t weight_num = weight.numel();
+  Tensor result = at::empty_like(input);
+  auto strides = input.strides();
+
+  // case 1: shared weight for all channels
+  if (weight_num == 1) {
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+      prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
+    });
+  }
+  else { // case 2: multiple weights, one for each channel
+    int64_t input_ndim = input.dim();
+    AT_CHECK(input_ndim > 0, "Zero-dim input tensor is not allowed.");
+
+    int64_t channel_size = 1; // channel_size defaults to 1
+    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
+
+    if (input_ndim > 1) {
+      channel_size = input.size(1); // channel is the 2nd dim of input
+      input_dim0_size = input.size(0);
+      input_stride0 = strides[0];
+      input_stride1 = strides[1];
+    }
+    AT_CHECK(channel_size == weight_num,
+      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ",
+      weight_num, " and channel size = ", channel_size, ".");
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_cpu", [&] {
+      prelu_cpu_kernel_multi_weights<scalar_t>(
+        result,
+        input,
+        weight,
+        input_dim0_size,
+        channel_size,
+        input_stride0,
+        input_stride1);
+    });
+  }
+  return result;
+}
+
+// -----------------------------------
+// prelu backward
+// -----------------------------------
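+// For out = PReLU(x): d(out)/dx is 1 where x > 0 and weight where x <= 0, so
+// grad_input = grad_out * that factor. d(out)/d(weight) is x where x <= 0 and
+// 0 elsewhere, so the weight gradient sums grad_out * x over those positions.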
+template <typename scalar_t>
+void inline prelu_cpu_backward_kernel_share_weights(
+  const Tensor& input,
+  const Tensor& weight,
+  const Tensor& grad_out,
+  Tensor& input_grad,
+  Tensor& weight_grad) {
+
+  int64_t i;
+  int64_t input_numel = input.numel();
+  scalar_t sum = 0;
+  auto input_data = input.data<scalar_t>();
+  auto weight_val = weight.data<scalar_t>()[0];
+  auto grad_out_data = grad_out.data<scalar_t>();
+  auto input_grad_data = input_grad.data<scalar_t>();
+  auto weight_grad_data = weight_grad.data<scalar_t>();
+
+  #pragma omp parallel for private(i) reduction(+:sum) if (input_numel > 1000)
+  for (i = 0; i < input_numel; i++) {
+    scalar_t input_data_val = input_data[i];
+    scalar_t grad_out_data_val = grad_out_data[i];
+    // split into two lines so the compiler can optimize the branch:
+    scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_val;
+    input_grad_data[i] = w * grad_out_data_val;
+    // split into two lines so the compiler can optimize the branch:
+    scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
+    sum += mask * input_data_val * grad_out_data_val;
+  }
+  weight_grad_data[0] = sum;
+}
+
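+// Per-channel weights: the kernel writes grad_input directly and stages the
+// per-element weight gradients in weight_grad_collector; the caller reduces
+// them to one value per channel.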
+template <typename scalar_t>
+void inline prelu_cpu_backward_kernel_multi_weights(
+  const Tensor& input,
+  const Tensor& weight,
+  const Tensor& grad_out,
+  Tensor& input_grad,
+  Tensor& weight_grad_collector,
+  int64_t input_dim0_size,
+  int64_t channel_size,
+  int64_t input_stride0,
+  int64_t input_stride1) {
+
+  int64_t i, j, k;
+  int64_t input_numel = input.numel();
+  auto input_data = input.data<scalar_t>();
+  auto weight_data = weight.data<scalar_t>();
+  auto grad_out_data = grad_out.data<scalar_t>();
+  auto input_grad_data = input_grad.data<scalar_t>();
+  auto weight_grad_collector_data = weight_grad_collector.data<scalar_t>();
+
+  #pragma omp parallel for private(i, j, k) if (input_numel > 1000)
+  for (i = 0; i < input_dim0_size; i++) {
+    for (j = 0; j < channel_size; j++) {
+      for (k = 0; k < input_stride1; k++) {
+        int64_t pos = i * input_stride0 + j * input_stride1 + k;
+        scalar_t weight_data_val = weight_data[j];
+        scalar_t input_data_val = input_data[pos];
+        scalar_t grad_out_data_val = grad_out_data[pos];
+        // split into two lines so the compiler can optimize the branch:
+        scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_data_val;
+        input_grad_data[pos] = w * grad_out_data_val;
+        // split into two lines so the compiler can optimize the branch:
+        scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
+        weight_grad_collector_data[pos] = mask * input_data_val * grad_out_data_val;
+      }
+    }
+  }
+}
+
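+// Returns (grad_input, grad_weight) for PReLU, given grad_out, the forward
+// input, and the weight.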
+std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
+  auto input = self.contiguous();
+  auto grad_out = grad_out_.contiguous();
+  auto weight = weight_.contiguous();
+
+  AT_CHECK(input.is_contiguous());
+  AT_CHECK(grad_out.is_contiguous());
+  AT_CHECK(weight.is_contiguous());
+
+  int64_t weight_num = weight.numel();
+  auto strides = input.strides();
+  auto dims = input.dim();
+
+  Tensor input_grad = at::empty_like(input);
+  Tensor weight_grad = at::empty_like(weight);
+  Tensor weight_grad_collector = at::empty_like(input);
+
+  // case 1: shared parameter for all channels
+  if (weight_num == 1) {
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+      prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
+    });
+  }
+  else { // case 2: multiple parameters, one for each channel
+    int64_t input_ndim = input.dim();
+    AT_CHECK(input_ndim > 0, "Zero-dim input tensor is not allowed.");
+
+    int64_t channel_size = 1; // channel_size defaults to 1
+    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
+
+    if (input_ndim > 1) {
+      channel_size = input.size(1); // channel is the 2nd dim of input
+      input_dim0_size = input.size(0);
+      input_stride0 = strides[0];
+      input_stride1 = strides[1];
+    }
+    AT_CHECK(channel_size == weight_num,
+      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ",
+      weight_num, " and channel size = ", channel_size, ".");
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "prelu_backward_cpu", [&] {
+      prelu_cpu_backward_kernel_multi_weights<scalar_t>(
+        input,
+        weight,
+        grad_out,
+        input_grad,
+        weight_grad_collector,
+        input_dim0_size,
+        channel_size,
+        input_stride0,
+        input_stride1);
+    });
+    // update weight_grad
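+    // (reduce the per-element gradients over every dimension except the
+    // channel dimension, dim 1)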
+    std::vector<int64_t> reduce_dims;
+    reduce_dims.push_back(0);
+    if (dims > 2) {
+      for (int64_t i = 2; i < dims; i++) reduce_dims.push_back(i);
+    }
+    weight_grad = weight_grad_collector.sum(reduce_dims);
+  }
+  return std::tuple<Tensor, Tensor>{input_grad, weight_grad};
+}
+
+// -----------------------------------
+// hardshrink
+// -----------------------------------
 Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
   AT_DISPATCH_FLOATING_TYPES(self.type(), "hardshrink_cpu", [&] {