/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/sort.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_ops.h"

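// CUDA limits a thread block to 1024 threads on every architecture TensorFlow
// supports, so embedding rows wider than this are split across several blocks
// below.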
constexpr int MAX_THREADS_PER_BLOCK = 1024;

namespace tensorflow {
namespace addons {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

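// Copies `indices` into `sortedIndices` and fills `sortedIndicesCounter` with
// 0..indices_size-1, one thread per element. A later thrust::sort_by_key sorts
// the copy while permuting the counter identically, which records where each
// index value originally lived.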
template <typename Tindices, const int kThreadsPerBlock>
__global__ void PrepTempArraysKernel(
    const Tindices *__restrict__ indices, Tindices *__restrict__ sortedIndices,
    Tindices *__restrict__ sortedIndicesCounter, const int indices_size) {
  const int arrayIdx = (blockIdx.x * kThreadsPerBlock) + threadIdx.x;
  // Make sure we don't run off the end of the actual array
  if (arrayIdx < indices_size) {
    sortedIndices[arrayIdx] = indices[arrayIdx];
    sortedIndicesCounter[arrayIdx] = arrayIdx;
  }
}

// Define the CUDA kernel.
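// Gradient with respect to the per-(sample, bag) weights. Each thread block
// handles one (sample, bag) pair and computes the dot product between the
// selected embedding row and that sample's incoming gradient row.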
template <typename T, typename Tindices, const int kThreadsPerBlock>
__global__ void EmbeddingBagWeightsGradKernel(
    const int value_dim, const Tindices *__restrict__ indices,
    const T *__restrict__ values, const T *__restrict__ dloss,
    T *__restrict__ weights_grad) {
  const int sample_idx = blockIdx.x;
  const int bag_idx = blockIdx.y;
  const int bag_dim = gridDim.y;
  const int valueBaseIdx =
      indices[(sample_idx * bag_dim) + bag_idx] * value_dim;
  const int dlossBaseIdx = sample_idx * value_dim;
  // Use a full-precision accumulator even for half-precision inputs.
  // Note that some threads may stop one iteration earlier if the block
  // straddles the end of the array.
  float partialDotProduct = 0.0f;
  for (int i = threadIdx.x; i < value_dim; i += blockDim.x) {
    partialDotProduct +=
        static_cast<float>(values[valueBaseIdx + i] * dloss[dlossBaseIdx + i]);
  }
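  // Warp-level tree reduction over the per-thread partial sums: each
  // __shfl_down_sync halves the number of live partials, so after
  // log2(kThreadsPerBlock) steps lane 0 holds the complete dot product. The
  // launch below uses kThreadsPerBlock == 32, i.e. the block is a single warp.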
  unsigned activeMask = 0xffffffff;
#pragma unroll
  for (int offset = kThreadsPerBlock / 2; offset > 0; offset /= 2) {
    partialDotProduct +=
        __shfl_down_sync(activeMask, partialDotProduct, offset);
  }
  // Thread 0 now has the full dot product
  if (threadIdx.x == 0) {
    weights_grad[(sample_idx * bag_dim) + bag_idx] =
        static_cast<T>(partialDotProduct);
  }
}

template <typename T, typename Tindices>
__global__ void EmbeddingBagValuesGradKernel(
    const int value_dim, const int bag_dim,
    const Tindices *__restrict__ sortedIndices,
    const Tindices *__restrict__ counter, const T *__restrict__ values,
    const T *__restrict__ weights, const T *__restrict__ dloss,
    T *__restrict__ values_grad) {
  const int startIdx = blockIdx.x;
  const int chunk = blockIdx.y;
  const int kThreadsPerBlock = blockDim.x;
  const int featureIdx = threadIdx.x + (chunk * kThreadsPerBlock);
  // The core problem here is that we want to avoid parallel writes to the
  // same element of the grads. We avoid that by pre-sorting a copy of the
  // indices tensor, and also co-sorting a 'counter' array so that we still
  // know which element of the incoming gradient tensor corresponds to each.
  // Then we take the slightly lazy approach of launching a thread block for
  // each element of the indices array, but having each block check the
  // previous element before it starts. If the two elements are the same, the
  // block returns immediately without doing anything. If not, the block
  // iterates forward and accumulates gradient until it hits a different index
  // element, at which point it writes the accumulated value and returns. This
  // ensures that each row of the values grad tensor is written from exactly
  // one starting position in the sorted indices array, so no two blocks ever
  // write the same element.
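  // For example (hypothetical input), with sortedIndices = [2, 2, 5, 7, 7, 7]:
  // blocks at startIdx 0 accumulate positions 0-1 into values_grad row 2,
  // startIdx 1 returns immediately, startIdx 2 handles position 2 (row 5),
  // startIdx 3 accumulates positions 3-5 into row 7, and startIdx 4 and 5
  // return immediately.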
  const int valuesIdx = ldg(sortedIndices + startIdx);
  if (startIdx > 0) {
    const int prevIdx = ldg(sortedIndices + startIdx - 1);
    if (prevIdx == valuesIdx) {
      return;  // Another block is handling this index, exit
    }
  }
  int endIdx = startIdx;
  while (endIdx < gridDim.x - 1)  // Don't run off the end of the array
  {
    int nextIdx = endIdx + 1;
    int nextValuesIdx = ldg(sortedIndices + nextIdx);
    if (nextValuesIdx == valuesIdx) {
      endIdx += 1;
    } else {
      break;
    }
  }
  if (featureIdx < value_dim)  // Don't run off the end of the row
  {
    const int outputOffset = (valuesIdx * value_dim) + featureIdx;
    float accum = 0.0f;  // Full precision even if the inputs aren't

    for (int currentIdx = startIdx; currentIdx <= endIdx; ++currentIdx) {
      int originalIdxPosition = ldg(counter + currentIdx);
      T weight = weights[originalIdxPosition];
      // Integer (floor) division recovers the sample this index came from;
      // that sample's row of dloss starts at sample * value_dim.
      T featureDloss = ldg(dloss + (originalIdxPosition / bag_dim) * value_dim +
                           featureIdx);
      accum += static_cast<float>(weight * featureDloss);
    }
    values_grad[outputOffset] = static_cast<T>(accum);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T, typename Tindices>
struct EmbeddingBagBackwardFunctor<GPUDevice, T, Tindices> {
  // `indices` must remain unchanged, but thrust cannot sort through a const
  // pointer, so the functor sorts a scratch copy of it instead.
  void operator()(const GPUDevice &d,
                  typename TTypes<Tindices, 2>::ConstTensor indices,
                  typename TTypes<T, 2>::ConstTensor params,
                  typename TTypes<T, 2>::ConstTensor weights,
                  typename TTypes<T, 2>::ConstTensor grads,
                  typename TTypes<T, 2>::Tensor params_grads,
                  typename TTypes<T, 2>::Tensor weights_grads,
                  Combiner combiner, OpKernelContext *context) {
    // Allocator attributes for the temporary GPU tensors (same setup as
    // histogram_op_gpu.cu.cc).
    tensorflow::AllocatorAttributes gpu_allocator;
    gpu_allocator.set_on_host(false);
    gpu_allocator.set_gpu_compatible(true);

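    // Temporary buffers: a sortable copy of the flattened indices and a
    // matching array holding each element's original flat position.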
    Tensor sortedIndicesTensor;
    Tensor sortedIndicesCounterTensor;

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tindices>::value,
                                          TensorShape({indices.size()}),
                                          &sortedIndicesTensor, gpu_allocator));
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<Tindices>::value,
                                TensorShape({indices.size()}),
                                &sortedIndicesCounterTensor, gpu_allocator));
    auto sortedIndices = sortedIndicesTensor.flat<Tindices>();
    auto sortedIndicesCounter = sortedIndicesCounterTensor.flat<Tindices>();
    // Note: I tried splitting the two kernels into different streams but
    // performance was barely affected.
    const Eigen::Index batch_dim = indices.dimension(0);
    const Eigen::Index bag_dim = indices.dimension(1);
    const Eigen::Index output_dim = params.dimension(1);
    const auto params_size = params.size();
    const int kThreadsPerBlock = 32;
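    // One warp (32 threads) per (sample, bag) pair: grid.x indexes the sample
    // and grid.y the position within the bag.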
    dim3 gridShape = dim3(batch_dim, bag_dim, 1);
    TF_CHECK_OK(GpuLaunchKernel(
        EmbeddingBagWeightsGradKernel<T, Tindices, kThreadsPerBlock>, gridShape,
        kThreadsPerBlock, 0, d.stream(), output_dim, indices.data(),
        params.data(), grads.data(), weights_grads.data()));

    const int indices_size = indices.size();
    const int values_size = params.size();
    const int total_blocks = Eigen::divup(indices_size, kThreadsPerBlock);
    gridShape = dim3(total_blocks, 1, 1);

    TF_CHECK_OK(GpuLaunchKernel(
        PrepTempArraysKernel<Tindices, kThreadsPerBlock>, gridShape,
        kThreadsPerBlock, 0, d.stream(), indices.data(), sortedIndices.data(),
        sortedIndicesCounter.data(), indices_size));

    thrust::device_ptr<Tindices> sortedIndicesCounterDevicePtr(
        sortedIndicesCounter.data());
    thrust::device_ptr<Tindices> sortedIndicesDevicePtr(sortedIndices.data());
    thrust::device_ptr<T> paramsGradDevicePtr(params_grads.data());
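    // Rows of params_grads whose index never appears in `indices` are not
    // written by the values-grad kernel, so zero the whole tensor first.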
    thrust::fill(paramsGradDevicePtr,
                 paramsGradDevicePtr + static_cast<int>(params_size),
                 static_cast<T>(0.0f));
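    // Sort the index copy and apply the same permutation to the counter array,
    // so equal indices become contiguous runs whose entries still point back
    // to their original (sample, bag) positions.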
    thrust::sort_by_key(sortedIndicesDevicePtr,
                        sortedIndicesDevicePtr + indices_size,
                        sortedIndicesCounterDevicePtr);
    // Handle each row with as few thread blocks as possible
    int threadsPerBlock;
    int blocksPerRow;
    if (output_dim <= MAX_THREADS_PER_BLOCK) {
      blocksPerRow = 1;
      threadsPerBlock = output_dim;
    } else {
      blocksPerRow =
          Eigen::divup(static_cast<int>(output_dim), MAX_THREADS_PER_BLOCK);
      threadsPerBlock =
          Eigen::divup(static_cast<int>(output_dim), blocksPerRow);
    }
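    // e.g. output_dim = 1500 gives blocksPerRow = 2 and threadsPerBlock = 750.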
    gridShape = dim3(indices_size, blocksPerRow, 1);
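    // grid.x launches one block per element of the sorted indices array;
    // grid.y covers the embedding width in chunks of threadsPerBlock.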
    TF_CHECK_OK(GpuLaunchKernel(
        EmbeddingBagValuesGradKernel<T, Tindices>, gridShape, threadsPerBlock,
        0, d.stream(), output_dim, bag_dim, sortedIndices.data(),
        sortedIndicesCounter.data(), params.data(), weights.data(),
        grads.data(), params_grads.data()));
  }
};

// Explicitly instantiate functors for the types of OpKernels registered.
template struct EmbeddingBagBackwardFunctor<GPUDevice, double, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, float, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, Eigen::half, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, double, int64>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, float, int64>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, Eigen::half, int64>;
}  // namespace functor
}  // namespace addons
}  // namespace tensorflow

#endif  // GOOGLE_CUDA