diff --git a/detr_tf/custom_ops/ms_deform_attn/__init__.py b/detr_tf/custom_ops/ms_deform_attn/__init__.py
new file mode 100644
index 00000000..f27f6062
--- /dev/null
+++ b/detr_tf/custom_ops/ms_deform_attn/__init__.py
@@ -0,0 +1,25 @@
+import os.path
+import tensorflow as tf
+
+if tf.test.is_built_with_cuda():
+    _cuda_op_module = tf.load_op_library(os.path.join(
+        tf.compat.v1.resource_loader.get_data_files_path(), 'ms_deform_im2col.so'))
+    ms_deform_im2col = _cuda_op_module.ms_deform_im2col
+
+
+    @tf.RegisterGradient("MsDeformIm2col")
+    def _zero_out_grad(op, grad):
+        grad_value, grad_sampling_loc, grad_attn_weight = _cuda_op_module.ms_deform_im2col_grad(
+            op.inputs[0],
+            op.inputs[1],
+            op.inputs[2],
+            op.inputs[3],
+            op.inputs[4],
+            grad
+        )
+
+        return grad_value, None, None, grad_sampling_loc, grad_attn_weight
+
+
+else:
+    raise ValueError("Trying to load cuda ms_deform_im2col without cuda support")
\ No newline at end of file
diff --git a/detr_tf/custom_ops/ms_deform_attn/build.sh b/detr_tf/custom_ops/ms_deform_attn/build.sh
new file mode 100755
index 00000000..17a3057b
--- /dev/null
+++ b/detr_tf/custom_ops/ms_deform_attn/build.sh
@@ -0,0 +1,7 @@
+# With tf env activated
+TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
+TF_LFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
+
+nvcc -std=c++11 -c -o ms_deform_im2col.o ms_deform_im2col.cu.cc ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
+
+g++ -std=c++11 -shared -o ms_deform_im2col.so ms_deform_im2col.cc ms_deform_im2col.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}
diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py b/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py
new file mode 100644
index 00000000..29b28aa7
--- /dev/null
+++ b/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py
@@ -0,0 +1,187 @@
+import tensorflow as tf
+
+# Python fallback of MSDeformAttnFunction
+
+def MSDeformAttnFunction(values, sampling_locations, attention_weights):
+
+    # For debug and test only,
+    # use the cuda version instead
+    """
+    :param values: per level, (N, H, W, num_heads, head_dim)
+    :param sampling_locations: per level, (N, Len_q, num_heads, num_sampling_points, 2)
+    :param attention_weights: (N, Len_q, num_heads, num_level, num_sampling_points)
+    """
+
+    sampling_value_list = []
+    for lid_, (value, sl) in enumerate(zip(values, sampling_locations)):
+        N, h_l, w_l, num_heads, head_dim = tf.unstack(tf.shape(value))
+        # N*num_heads, h, w, c
+        value = tf.reshape(tf.transpose(value, [0, 3, 1, 2, 4]), [N*num_heads, h_l, w_l, head_dim])
+
+        # N, Len_q, num_heads, num_sampling_points, 2
+        sl = 2 * sl - 1  # between (-1, 1)
+        N, Len_q, num_heads, num_sampling_points, _ = tf.unstack(tf.shape(sl))
+
+        # N*num_heads, Len_q, num_sampling_points, 2
+        sampling_grid_l_ = tf.reshape(tf.transpose(sl, [0, 2, 1, 3, 4]), [N*num_heads, Len_q, num_sampling_points, 2])
+
+        # N*num_heads, Len_q, num_sampling_points, c
+        if True:
+            sampled_values = bilinear_sampler(value, sampling_grid_l_)
+        else:
+            sampled_values = nearest_sampler(value, sampling_grid_l_)
+
+        sampling_value_list.append(sampled_values)
+
+    # N*num_heads, Len_q, num_level, num_sampling_points, c
+    sampling_value = tf.stack(sampling_value_list, axis=2)
+    # N, num_heads, Len_q, num_level, num_sampling_points, c
+    sampling_value = tf.reshape(sampling_value, (N, num_heads, Len_q, len(values), num_sampling_points, head_dim))
+    # N, Len_q, num_heads, num_level, num_sampling_points, c
+    sampling_value = tf.transpose(sampling_value, [0, 2, 1, 3, 4, 5])
+    # (N, Len_q, num_heads, num_level, num_sampling_points, 1)
+    attention_weights = tf.expand_dims(attention_weights, -1)
+    # N, Len_q, num_heads, num_level, num_sampling_points, c
+    output = attention_weights * sampling_value
+    # N, Len_q, num_heads, -1, head_dim
+    output = tf.reshape(output, (N, Len_q, num_heads, -1, head_dim))
+    # N, Len_q, num_heads, c
+    output = tf.reduce_sum(output, axis=3)
+
+    output = tf.reshape(output, (N, Len_q, num_heads*head_dim))
+
+    return output
+
+
+def within_bounds(x, lower, upper):
+    lower_tensor = tf.greater_equal(x, lower)
+    upper_tensor = tf.less_equal(x, upper)
+    return tf.logical_and(lower_tensor, upper_tensor)
+
+def bilinear_sampler(image, coords):
+    ''' Value sampler using tf.gather_nd
+    Args:
+      image: tensor with shape (bs, h, w, c)
+      coords: coordinates tensor with shape (bs, ..., 2), xy-indexing between -1 and 1
+
+    Returns:
+      sampled tensor with shape (bs, ..., c)
+    '''
+
+    # Corresponds to padding="zeros" (optimistic: discard only the out-of-bound bilinear coefficients, not the full value)
+
+    with tf.name_scope("bilinear_sampler"):
+        _, h, w, _ = tf.unstack(tf.shape(image))
+
+
+        gx, gy = tf.unstack(coords, axis=-1)
+
+        # rescale x and y to [0, W-1] / [0, H-1]
+        gx = (gx+1.0)/2.0 * tf.cast(w-1, tf.float32)
+        gy = (gy+1.0)/2.0 * tf.cast(h-1, tf.float32)
+
+        gx0 = tf.floor(gx)
+        gx1 = gx0 + 1.0
+        gy0 = tf.floor(gy)
+        gy1 = gy0 + 1.0
+
+        mx0 = within_bounds(gx0, 0, tf.cast(w, tf.float32)-1)
+        mx1 = within_bounds(gx1, 0, tf.cast(w, tf.float32)-1)
+        my0 = within_bounds(gy0, 0, tf.cast(h, tf.float32)-1)
+        my1 = within_bounds(gy1, 0, tf.cast(h, tf.float32)-1)
+
+        c00 = tf.expand_dims((gy1 - gy)*(gx1 - gx), axis=-1)
+        c01 = tf.expand_dims((gy1 - gy)*(gx - gx0), axis=-1)
+        c10 = tf.expand_dims((gy - gy0)*(gx1 - gx), axis=-1)
+        c11 = tf.expand_dims((gy - gy0)*(gx - gx0), axis=-1)
+
+        # clip for CPU (out-of-bound error), optional on GPU (as the corresponding m00..m11 masks will be zeroed)
+        gx0 = tf.clip_by_value(gx0, 0, tf.cast(w, tf.float32)-1)
+        gx1 = tf.clip_by_value(gx1, 0, tf.cast(w, tf.float32)-1)
+        gy0 = tf.clip_by_value(gy0, 0, tf.cast(h, tf.float32)-1)
+        gy1 = tf.clip_by_value(gy1, 0, tf.cast(h, tf.float32)-1)
+
+        g00 = tf.stack([gy0, gx0], axis=-1)
+        g01 = tf.stack([gy0, gx1], axis=-1)
+        g10 = tf.stack([gy1, gx0], axis=-1)
+        g11 = tf.stack([gy1, gx1], axis=-1)
+
+        m00 = tf.cast(tf.expand_dims(tf.logical_and(my0, mx0), axis=-1), tf.float32)
+        m01 = tf.cast(tf.expand_dims(tf.logical_and(my0, mx1), axis=-1), tf.float32)
+        m10 = tf.cast(tf.expand_dims(tf.logical_and(my1, mx0), axis=-1), tf.float32)
+        m11 = tf.cast(tf.expand_dims(tf.logical_and(my1, mx1), axis=-1), tf.float32)
+
+        x00 = tf.gather_nd(image, tf.cast(g00, dtype=tf.int32), batch_dims=1)
+        x01 = tf.gather_nd(image, tf.cast(g01, dtype=tf.int32), batch_dims=1)
+        x10 = tf.gather_nd(image, tf.cast(g10, dtype=tf.int32), batch_dims=1)
+        x11 = tf.gather_nd(image, tf.cast(g11, dtype=tf.int32), batch_dims=1)
+
+        output = c00 * x00 * m00 \
+               + c01 * x01 * m01 \
+               + c10 * x10 * m10 \
+               + c11 * x11 * m11
+
+        return output
+
+
+def nearest_sampler(image, coords):
+    with tf.name_scope("nearest_sampler"):
+        _, h, w, _ = tf.unstack(tf.shape(image))
+
+        gx, gy = tf.unstack(coords, axis=-1)
+
+        # rescale x and y to [0, W-1] / [0, H-1]
+        gx = (gx+1.0)/2.0 * tf.cast(w-1, tf.float32)
+        gy = (gy+1.0)/2.0 * tf.cast(h-1, tf.float32)
+
+        gx0 = tf.round(gx)
+        gy0 = tf.round(gy)
+
+        g00 = tf.stack([gy0, gx0], axis=-1)
+
+        return tf.gather_nd(image, tf.cast(g00, dtype=tf.int32), batch_dims=1)
+
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F
+
+    import numpy as np
+
+    for i in range(1000):
+
+        test_size = 100
+
+        grid_size = test_size
+        feature_len = 1
+        batch_size = test_size
+
+        grid_sampling_size = test_size
+
+        values = np.random.rand(batch_size, grid_size, grid_size, feature_len)
+
+        t_values = np.transpose(values, (0, 3, 1, 2) )
+
+        coords = np.random.rand(batch_size, grid_sampling_size, grid_sampling_size, 2) * 2 - 1
+        coords = coords * 1.1
+
+        values = values.astype(np.float32)
+        coords = coords.astype(np.float32)
+        t_values = t_values.astype(np.float32)
+
+        tf_result = bilinear_sampler(values, coords)
+        tf_result = tf_result.numpy()
+
+        torch_result = F.grid_sample(torch.from_numpy(t_values), torch.from_numpy(coords),
+                                     mode='bilinear', padding_mode='zeros', align_corners=True)
+
+
+        torch_result = torch_result.view(batch_size, grid_sampling_size, grid_sampling_size, feature_len).numpy()
+
+        diff = np.abs(tf_result - torch_result)
+
+        print("diff", np.amax(diff), np.unravel_index(diff.argmax(), diff.shape))
+
+        if np.amax(diff) > 1e-3:
+            break
diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc
new file mode 100644
index 00000000..afcef79a
--- /dev/null
+++ b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc
@@ -0,0 +1,226 @@
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using namespace tensorflow; // NOLINT(build/namespaces)
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+/*
+:param values: per level, (N, H, W, num_heads, head_dim)
+:param sampling_locations: per level, (N, Len_q, num_heads, num_sampling_points, 2)
+:param attention_weights: (N, Len_q, num_heads, num_level, num_sampling_points)
+*/
+
+
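+// Forward computation (implemented in ms_deform_im2col.cu.cc): for each batch n,
+// query q, head m and channel c,
+//   col[n, q, m*channels + c] = sum over levels l and points p of
+//     attn_weight[n, q, m, l, p] * bilinear_sample(value of level l, head m, channel c,
+//                                                  at sampling_loc[n, q, m, l, p])
+// Sampling locations are expressed in [0, 1] and rescaled to (H_l - 1, W_l - 1) in the kernel.
+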
+REGISTER_OP("MsDeformIm2col") + .Input("value: float") // (N, Len_in, n_heads, d_model//n_heads) + .Input("spatial_shapes: int32") // (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + .Input("level_start_index: int32") // (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + .Input("sampling_loc: float") // (N, Len_q, n_heads, n_levels, n_points, 2) + .Input("attn_weight: float") // (N, Len_q, num_heads, n_level, num_sampling_points) + .Attr("im2col_step:int = 64") + .Output("col: float") // N, Len_q, num_heads*head_dim + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + auto batch_size = c->Dim(c->input(0), 0); + auto num_heads = c->Dim(c->input(0), 2); + auto channels = c->Dim(c->input(0), 3); + auto num_query = c->Dim(c->input(3), 1); + auto outChannels = c->MakeDim(round(c->Value(num_heads)*c->Value(channels))); + c->set_output(0, c->MakeShape({batch_size, num_query, outChannels})); + + return Status::OK(); + }); + + + +REGISTER_OP("MsDeformIm2colGrad") + .Input("value: float") // (N, Len_in, n_heads, d_model//n_heads) + .Input("spatial_shapes: int32") // (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + .Input("level_start_index: int32") // (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + .Input("sampling_loc: float") // (N, Len_q, n_heads, n_levels, n_points, 2) + .Input("attn_weight: float") // (N, Len_q, num_heads, n_level, num_sampling_points) + .Input("grad_output: float") // N, Len_q, num_heads*head_dim + .Attr("im2col_step:int = 64") + .Output("grad_value: float") + .Output("grad_sampling_loc: float") + .Output("grad_attn_weight: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(3)); + c->set_output(2, c->input(4)); + return Status::OK(); + }); + + +void ms_deformable_col2im_cuda(const GPUDevice& d, + const float* grad_col, + const float* value, + const int * spatial_shapes, + const int * level_start_index, + const float * sampling_loc, + const float * attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* grad_value, + float* grad_sampling_loc, + float* grad_attn_weight); + +void ms_deformable_im2col_cuda(const GPUDevice& d, + const float* value, + const int* spatial_shapes, + const int* level_start_index, + const float* sampling_loc, + const float* attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* col); + + + + + +template +class MsDeformIm2colOp : public OpKernel { + public: + explicit MsDeformIm2colOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("im2col_step", &im2col_step_)); + OP_REQUIRES(context, im2col_step_ >= 0, + errors::InvalidArgument("Need im2col_step_ >= 0, got ", + im2col_step_)); + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& value = context->input(0); + const Tensor& spatial_shapes = context->input(1); + const Tensor& level_start_index = context->input(2); + const Tensor& sampling_loc = context->input(3); + const Tensor& attn_weight = context->input(4); + + const int batch_size = value.dim_size(0); + const int spatial_size = 
value.dim_size(1); + const int num_heads = value.dim_size(2); + const int channels = value.dim_size(3); + const int num_levels = spatial_shapes.dim_size(0); + const int num_query = sampling_loc.dim_size(1); + const int num_point = sampling_loc.dim_size(4); + + const int im2col_step = std::min(batch_size, im2col_step_); + + Tensor* output_tensor = nullptr; + + TensorShape output_tensor_shape = TensorShape({batch_size, num_query, num_heads*channels}); + + OP_REQUIRES_OK(context, context->allocate_output(0, output_tensor_shape, &output_tensor)); + auto col = output_tensor->flat(); + + + // Call the cuda kernel launcher + ms_deformable_im2col_cuda(context->eigen_gpu_device(), + value.flat().data(), + spatial_shapes.flat().data(), + level_start_index.flat().data(), + sampling_loc.flat().data(), + attn_weight.flat().data(), + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + col.data()); + } + private: + int im2col_step_; +}; + + + +template +class MsDeformIm2colGradOp : public OpKernel { + public: + explicit MsDeformIm2colGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("im2col_step", &im2col_step_)); + OP_REQUIRES(context, im2col_step_ >= 0, + errors::InvalidArgument("Need im2col_step_ >= 0, got ", + im2col_step_)); + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& value = context->input(0); + const Tensor& spatial_shapes = context->input(1); + const Tensor& level_start_index = context->input(2); + const Tensor& sampling_loc = context->input(3); + const Tensor& attn_weight = context->input(4); + const Tensor& grad_output = context->input(5); + + const int batch_size = value.dim_size(0); + const int spatial_size = value.dim_size(1); + const int num_heads = value.dim_size(2); + const int channels = value.dim_size(3); + const int num_levels = spatial_shapes.dim_size(0); + const int num_query = sampling_loc.dim_size(1); + const int num_point = sampling_loc.dim_size(4); + + Tensor* output_tensor_value = nullptr; + Tensor* output_tensor_sampling_loc = nullptr; + Tensor* output_tensor_attn_weight = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, value.shape(), &output_tensor_value)); + OP_REQUIRES_OK(context, context->allocate_output(1, sampling_loc.shape(), &output_tensor_sampling_loc)); + OP_REQUIRES_OK(context, context->allocate_output(2, attn_weight.shape(), &output_tensor_attn_weight)); + + auto output_flat = output_tensor_value->flat(); + + + // Call the cuda kernel launcher + ms_deformable_col2im_cuda(context->eigen_gpu_device(), + grad_output.flat().data(), + value.flat().data(), + spatial_shapes.flat().data(), + level_start_index.flat().data(), + sampling_loc.flat().data(), + attn_weight.flat().data(), + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + output_tensor_value->template flat().data(), + output_tensor_sampling_loc->template flat().data(), + output_tensor_attn_weight->template flat().data()); + + } + private: + int im2col_step_; +}; + + + + + + +REGISTER_KERNEL_BUILDER(Name("MsDeformIm2col").Device(DEVICE_GPU), MsDeformIm2colOp); + +REGISTER_KERNEL_BUILDER(Name("MsDeformIm2colGrad").Device(DEVICE_GPU), MsDeformIm2colGradOp); \ No newline at end of file diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc new file mode 100644 index 00000000..6410a01b --- /dev/null +++ 
b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc @@ -0,0 +1,1353 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +//#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" +#include "tensorflow/core/util/gpu_cuda_alias.h" +//#include "tensorflow/core/util/gpu_device_functions.h" + +typedef Eigen::GpuDevice GPUDevice; + + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride 
= width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = (width-1) * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = (height-1) * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = 
h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); //- 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); //- 0.5; + + // if (h_im >= 0 && w_im >= 0 && h_im <= spatial_h-1 && w_im <= spatial_w-1) + // { + // const int h_r = round(h_im); + // const int w_r = round(w_im); + // const int w_stride = num_heads * channels; + // const int h_stride = spatial_w * w_stride; + // const int h_ptr_offset = h_r * h_stride; + // const int w_ptr_offset = w_r * w_stride; + // const int base_ptr = m_col * channels + c_col; + + // const int ptr1 = h_ptr_offset + w_ptr_offset + base_ptr; + + // col += data_value_ptr[ptr1]* weight; + // } + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const 
scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int 
batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + 
const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const 
int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + 
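+// Multi-block variant: selected by the launcher when channels exceeds the maximum
+// block size (channels > 1024 and a multiple of 1024), so the channels of a single
+// (query, head) pair are spread over several thread blocks. Each block reduces its
+// partial gradients in shared memory, then accumulates them into grad_sampling_loc /
+// grad_attn_weight with atomicAdd, since other blocks contribute to the same locations.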
+template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + 
__syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +void ms_deformable_im2col_cuda(const GPUDevice &d, + const float* data_value, + const int* data_spatial_shapes, + const int* data_level_start_index, + const float* data_sampling_loc, + const float* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const 
int num_threads = CUDA_NUM_THREADS; + + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + + +void ms_deformable_col2im_cuda(const GPUDevice &d, + const float* grad_col, + const float* data_value, + const int * data_spatial_shapes, + const int * data_level_start_index, + const float * data_sampling_loc, + const float * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* grad_value, + float* grad_sampling_loc, + float* grad_attn_weight) +{ + + cudaMemset(grad_value, 0, batch_size*num_heads*channels*spatial_size*sizeof(float)); + cudaMemset(grad_sampling_loc, 0, batch_size*num_query*num_heads*num_levels*num_point*2*sizeof(float)); + cudaMemset(grad_attn_weight, 0, batch_size*num_query*num_heads*num_levels*num_point*sizeof(float)); + + + + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + 
data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, 
+ grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o new file mode 100644 index 00000000..e025baa6 Binary files /dev/null and b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o differ diff --git a/detr_tf/custom_ops/ms_deform_attn/test.py b/detr_tf/custom_ops/ms_deform_attn/test.py new file mode 100644 index 00000000..80eadbc1 --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/test.py @@ -0,0 +1,206 @@ + +"""Cuda op Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import numpy as np +import tensorflow as tf + +import math + +from surroundnet.custom_ops.ms_deform_attn import ms_deform_im2col + +from surroundnet.custom_ops.ms_deform_attn.ms_deform_attn import MSDeformAttnFunction + + + +import torch +import torch.nn.functional as F + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + + #print(sampling_locations[:, :, :, 1, :, :]) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=True) + # N_*M_, Lq_, P_, D_ + #tmp = sampling_value_l_.permute(0, 2, 3, 1) + #print( f"sampling_value_l_{lid_}", tmp ) + # tmp = value_l_.permute(0, 2, 3, 1) + # print("value", tmp) + + # tmp = sampling_value_l_.permute(0, 2, 3, 1) + # print("sampled_values", tmp) + + # print("sampling_grid_l_", sampling_grid_l_) + # exit() + + sampling_value_list.append(sampling_value_l_) + #exit() + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + # (N_*M_, D_, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + #print(N_, M_, D_, Lq_, L_, P_) + + tmp = torch.stack(sampling_value_list, dim=-2).flatten(-2) + + # print("MSDeformAttnFunction_attention_weights", attention_weights.shape, attention_weights.view(N_, M_, 1, Lq_, L_, P_)) + # print("MSDeformAttnFunction_sampl", tmp.shape, tmp.view(N_, M_, D_, Lq_, L_, P_)) + + #print("MSDeformAttnFunction_att*sampl", output.shape, output.view(N_, M_, D_, Lq_, L_, P_)[0, 2, :, 1, 3, 2]) + + # (N_*M_, D_, Lq_) -> (N_, M_*D_, Lq_) + output = output.sum(-1).view(N_, M_*D_, Lq_) + # (N_, Lq_, M_*D_) + return output.transpose(1, 2).contiguous() + + + + + + + + + + +N = 1 + +n_heads = 8 +d_model = 256 +size = np.array( (128, 128) ) +Len_q = 13 + + +# n_heads = 4 +# d_model = 4 +# size = np.array( (16, 16) ) + +# Len_q = 3 + +n_levels = 4 +num_sampling_points = 4 + 
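+# Build an n_levels feature pyramid: each level holds a value tensor of shape
+# (N, H_l, W_l, n_heads, d_model//n_heads), the spatial size is halved from one level
+# to the next, and level_start_index accumulates the H_l*W_l offset of each level
+# inside the flattened value tensor.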
+values = list() +spatial_shapes = list() +level_start_index = [0] + +for i in range(n_levels): + value = tf.random.uniform( shape=(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads) ) * 0.01 + #value = tf.ones( shape=(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads) ) + values.append(value) + spatial_shapes.append(size) + level_start_index.append(size[0]*size[1] + level_start_index[-1]) + + size = size//2 + +flatten_attn_weight = tf.random.uniform( (N, Len_q, n_heads, n_levels, num_sampling_points) ) + 1e-5 +flatten_attn_weight /= tf.reduce_sum(tf.reduce_sum(flatten_attn_weight, axis=-1, keepdims=True), axis=-2, keepdims=True) + +#flatten_attn_weight = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points) ) + + +flatten_sampling_loc = tf.random.uniform( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), minval=-0.1, maxval=1.1, dtype=tf.float32 ) +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) *10 #(127+0.999) #* math.pi /10 +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) *0.5 #* math.pi /10 +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) * math.pi /10 + + + +level_start_index = np.array( level_start_index, dtype=np.int32 ) + +spatial_shapes = np.array( spatial_shapes, dtype=np.int32 ) + + +with tf.GradientTape(persistent=True) as g: + g.watch(flatten_sampling_loc) + g.watch(values) + g.watch(flatten_attn_weight) + + sampling_loc = tf.unstack(flatten_sampling_loc, axis=3) #(N, Len_q, n_heads, num_sampling_points) + flatten_value = tf.concat( [tf.reshape(v, (N, -1, n_heads, d_model//n_heads) ) for v in values], axis=1) + + py_res = MSDeformAttnFunction(values, sampling_loc, flatten_attn_weight) + + res = ms_deform_im2col( + flatten_value, # (N, Len_in, n_heads, d_model//n_heads) + spatial_shapes, # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + level_start_index, # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + flatten_sampling_loc, # (N, Len_q, n_heads, n_levels, n_points, 2) + flatten_attn_weight # (N, Len_q, n_heads, n_level, n_points) + ) + +# Save tensors to npy for TensorRT plugin test +# np.save("./flatten_value.npy", flatten_value.numpy()) +# np.save("./spatial_shapes.npy", spatial_shapes) +# np.save("./level_start_index.npy", level_start_index) +# np.save("./flatten_sampling_loc.npy", flatten_sampling_loc.numpy()) +# np.save("./flatten_attn_weight.npy", flatten_attn_weight.numpy()) +# np.save("./ms_deform_im2col_out.npy", res.numpy()) + +def check_value(name, py_grad, cu_grad): + print(name, py_grad.shape) + print("\t min value :", tf.reduce_min(py_grad), tf.reduce_min(cu_grad)) + print("\t max value :", tf.reduce_max(py_grad), tf.reduce_max(cu_grad)) + print("\t mean value :", tf.reduce_mean(py_grad), tf.reduce_mean(cu_grad)) + print("\t std value :", tf.math.reduce_std(py_grad), tf.math.reduce_std(cu_grad)) + abs_err = tf.reduce_max(tf.abs(py_grad - cu_grad)) + #coord = tf.math.argmax(tf.abs(py_grad - cu_grad)) + print("\t max abs error :", abs_err) + abs_err = tf.reduce_mean(tf.abs(py_grad - cu_grad)) + print("\t mean abs error :", abs_err) + #rel_err = tf.reduce_max(tf.abs(py_grad - cu_grad)/(tf.math.sqrt( tf.abs(py_grad)*tf.abs(cu_grad) ) + 1e-3) ) + #print("\t max rel error :", rel_err) + +check_value("VALUE python / CUDA", py_res, res) +#print(py_res) +#print(res) + +pytorch_res = 
ms_deform_attn_core_pytorch(
+    torch.from_numpy(flatten_value.numpy()),
+    spatial_shapes,
+    torch.from_numpy(flatten_sampling_loc.numpy()),
+    torch.from_numpy(flatten_attn_weight.numpy()) )
+
+check_value("VALUE pytorch / tensorflow", pytorch_res, res)
+
+
+#print(pytorch_res)
+check_value("GRAD Sampling Loc", g.gradient(py_res, flatten_sampling_loc), g.gradient(res, flatten_sampling_loc))
+check_value("GRAD Value", g.gradient(py_res, values[0]), g.gradient(res, values[0]))
+check_value("GRAD Attention", g.gradient(py_res, flatten_attn_weight), g.gradient(res, flatten_attn_weight))
+
+
+#(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads)
+#py_gvalue = g.gradient(py_res, values[0])
+#cu_gvalue = g.gradient(res, values[0])
+#print("cu_gvalue", cu_gvalue)
+#print("cu_gvalue", cu_gvalue)
+
+# args = [
+#     flatten_value,        # (N, Len_in, n_heads, d_model//n_heads)
+#     spatial_shapes,       # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+#     level_start_index,    # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+#     flatten_sampling_loc, # (N, Len_q, n_heads, n_levels, n_points, 2)
+#     flatten_attn_weight   # (N, Len_q, n_heads, n_level, n_points)
+# ]
+
+#CUDA
+#numerical, theoretical = tf.test.compute_gradient(ms_deform_im2col, args, delta=0.001)
\ No newline at end of file
diff --git a/detr_tf/data/coco.py
index e8f48557..d5f19bd9 100644
--- a/detr_tf/data/coco.py
+++ b/detr_tf/data/coco.py
@@ -72,6 +72,12 @@ def get_coco_from_id(coco_id, coco, augmentation, config, img_dir):
     # Apply augmentations
     if len(t_bbox) > 0 and augmentation is not None:
         image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation)
+
+    # If no instance is left in the image, add a single placeholder bbox/class filled with -1
+    # Boxes and classes set to -1 are ignored at training time
+    if len(t_bbox) == 0: t_bbox = np.zeros((1, 4)) - 1
+    if len(t_class) == 0: t_class = np.zeros((1, 1)) - 1
+
     # Normalized images
     image = processing.normalized_images(image, config)
     # Set type for tensorflow
@@ -79,17 +85,62 @@ def get_coco_from_id(coco_id, coco, augmentation, config, img_dir):
     t_bbox = t_bbox.astype(np.float32)
     t_class = t_class.astype(np.int64)
     is_crowd = np.array(is_crowd, dtype=np.int64)
-    return image, t_bbox, t_class, is_crowd
+
+    return image, t_bbox, t_class#, is_crowd
+
+
+def tensor_to_ragged(image, t_bbox, t_class):
+    # Images can have different sizes in multi-scale training.
+    # Each image can also have a different number of instances.
+    # Therefore, we use ragged tensors to handle tensors with dynamic shapes:
+    # None in a shape is treated as a dynamic dimension by the RaggedTensor.
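    # Note (editorial, hedged): the RaggedTensor round-trip below (from_tensor(...).to_tensor())
    # leaves the values unchanged; it presumably exists so the mapped dataset keeps fully
    # dynamic (None) dimensions in its shape signature, which dense_to_ragged_batch later
    # relies on to batch samples whose image sizes and instance counts differ.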
+    image.set_shape(tf.TensorShape([None, None, 3]))
+    image = tf.RaggedTensor.from_tensor(image).to_tensor()
+    t_bbox.set_shape(tf.TensorShape([None, 4]))
+    t_bbox = tf.RaggedTensor.from_tensor(t_bbox).to_tensor()
+    t_class.set_shape(tf.TensorShape([None, 1]))
+    t_class = tf.RaggedTensor.from_tensor(t_class).to_tensor()
+    return image, t_bbox, t_class
+
+
+def iter_tuple_to_dict(data):
+    image, t_bbox, t_class = data
+    return {
+        "images": image,
+        "target_bbox": t_bbox,
+        "target_class": t_class
+    }


-def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None):
+def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None, shuffle_data=True):
     """ Load a coco dataset
+
+    Parameters
+    ----------
+    config: TrainingConfig
+        Instance of TrainingConfig
+    batch_size: int
+        Size of the desired batch
+    augmentation: bool
+        Apply augmentations on the training data
+    ann_dir: str
+        Path to the annotation directory of the coco dataset
+        If None, will be equal to config.data.ann_dir
+    ann_file: str
+        Path to the ann_file relative to the ann_dir
+        If None, will be equal to config.data.ann_file
+    img_dir: str
+        Path to the img_dir relative to the data_dir
+        If None, will be equal to config.data.img_dir
+    shuffle_data: bool
+        Whether to shuffle the dataset (True by default)
     """
     ann_dir = config.data.ann_dir if ann_dir is None else ann_dir
-    ann_file = config.data.ann_file if ann_file is None else ann_file
-    img_dir = config.data.img_dir if img_dir is None else img_dir
-
-
+    if ann_dir is None:
+        ann_file = config.data.ann_file if ann_file is None else os.path.join(config.data_dir, ann_file)
+    else:
+        ann_file = config.data.ann_file if ann_file is None else os.path.join(ann_dir, ann_file)
+    img_dir = config.data.img_dir if img_dir is None else os.path.join(config.data_dir, img_dir)

     coco = COCO(ann_file)

@@ -106,22 +157,24 @@ def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_

     # Setup the data pipeline
     img_ids = coco.getImgIds()
-    shuffle(img_ids)
+
+    if shuffle_data:
+        shuffle(img_ids)
     dataset = tf.data.Dataset.from_tensor_slices(img_ids)

     # Shuffle the dataset
-    dataset = dataset.shuffle(1000)
+    if shuffle_data:
+        dataset = dataset.shuffle(1000)
+
     # Retrieve img and labels
-    outputs_types=(tf.float32, tf.float32, tf.int64, tf.int64)
+    outputs_types=(tf.float32, tf.float32, tf.int64)
     dataset = dataset.map(lambda idx: processing.numpy_fc(
         idx, get_coco_from_id, outputs_types=outputs_types, coco=coco, augmentation=augmentation, config=config, img_dir=img_dir)
     , num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    dataset = dataset.filter(lambda imgs, tbbox, tclass, iscrowd: tf.shape(tbbox)[0] > 0 and iscrowd != 1)
-    dataset = dataset.map(lambda imgs, tbbox, tclass, iscrowd: (imgs, tbbox, tclass), num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    # Pad bbox and labels
-    dataset = dataset.map(processing.pad_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dataset = dataset.map(tensor_to_ragged, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    dataset = dataset.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size, drop_remainder=True))
     dataset = dataset.prefetch(32)
+
+    dataset.itertuple2dict = lambda data: iter_tuple_to_dict(data)

     return dataset, class_names
\ No newline at end of file
diff --git a/detr_tf/data/processing.py
index 4629c398..be72b24c 100644
--- a/detr_tf/data/processing.py
+++ b/detr_tf/data/processing.py
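# --- Minimal usage sketch for the ragged COCO pipeline above (not part of the diff) ---
# Assumes `config` is a TrainingConfig with valid data paths; the dict keys come from
# iter_tuple_to_dict defined in coco.py.
#
#   dataset, class_names = load_coco_dataset(config, batch_size=2, augmentation=True)
#   for batch in dataset:
#       data = dataset.itertuple2dict(batch)
#       images = data["images"]           # batched images (sizes may differ per sample)
#       t_bbox = data["target_bbox"]      # (batch, n_instances, 4), -1 rows mark "no object"
#       t_class = data["target_class"]    # (batch, n_instances, 1), -1 marks "no object"
#       # With dense_to_ragged_batch, these may come back as ragged tensors when
#       # per-image shapes differ within the batch.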
@@ -28,8 +28,14 @@ def numpy_fc(idx, fc, outputs_types=(tf.float32, tf.float32, tf.int64), **params Call a numpy function on each given ID (`idx`) and load the associated image and labels (bbbox and cls) """ def _np_function(_idx): - return fc(_idx, **params) - return tf.numpy_function(_np_function, [idx], outputs_types) + data = fc(_idx, **params) + return data + + data = tf.numpy_function(_np_function, [idx], outputs_types) + + #data = tuple(map(lambda x : tf.RaggedTensor.from_tensor(x).to_tensor(), data)) + + return data def pad_labels(images: tf.Tensor, t_bbox: tf.Tensor, t_class: tf.Tensor): diff --git a/detr_tf/data/tfcsv.py b/detr_tf/data/tfcsv.py index 3503a56e..e2c27633 100644 --- a/detr_tf/data/tfcsv.py +++ b/detr_tf/data/tfcsv.py @@ -36,7 +36,27 @@ def load_data_from_index(index, class_names, filenames, anns, config, augmentati def load_tfcsv_dataset(config, batch_size, augmentation=False, exclude=[], ann_dir=None, ann_file=None, img_dir=None): - """ Load the hardhat dataset + """ Load a Tensorflow csv Dataset + + Parameters + ---------- + config: TrainingConfig + Instance of TrainingConfig + batch_size: int + Size of the desired batch size + augmentation: bool + Apply augmentations on the training data + exclude: list + Exclude some class from the training. Nothing happen if empty. + ann_dir: str + Path to the coco dataset + If None, will be equal to config.data.ann_dir + ann_file: str + Path to the ann_file relative to the ann_dir + If None, will be equal to config.data.ann_file + img_dir: str + Path to the img_dir relative to the data_dir + If None, will be equal to config.data.img_dir """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir ann_file = config.data.ann_file if ann_file is None else ann_file diff --git a/detr_tf/data/transformation.py b/detr_tf/data/transformation.py index be0e9cc8..9c2e3eac 100644 --- a/detr_tf/data/transformation.py +++ b/detr_tf/data/transformation.py @@ -2,12 +2,106 @@ import imgaug as ia import imgaug.augmenters as iaa import numpy as np +import random from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage from imgaug.augmentables.segmaps import SegmentationMapsOnImage import tensorflow as tf + +def get_size_with_aspect_ratio(w, h, size, max_size=None): + + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + +def get_multiscale_transform(images, + scales=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800], + random_crop=(384, 600), + random_resize=[400, 500, 600], + max_size=None): + """ Coco Augmentation pipeline + """ + h, w, _ = images.shape + + one_of = [] + + scale = np.random.choice(scales) + scale_height, scake_width = get_size_with_aspect_ratio(w, h, scale, max_size=max_size) + one_of.append( + iaa.Resize({"height": scale_height, "width": scake_width}) + ) + + random_resize_crop = [] + if random_resize is not None and len(random_resize) > 0: + scale = np.random.choice(random_resize) + resize_height, resize_width = get_size_with_aspect_ratio(w, h, scale) + random_resize_crop.append( + iaa.Resize({"height": resize_height, "width": resize_width}) + ) + if random_crop is not None: + crop_width = random.randint(random_crop[0], 
random_crop[1]) + crop_height = random.randint(random_crop[0], random_crop[1]) + random_resize_crop.append( + iaa.CropToFixedSize(crop_width, crop_height) + ) + + random_resize_crop.append( + iaa.Resize({"height": scale_height, "width": scake_width}) + ) + + one_of.append(iaa.Sequential(random_resize_crop)) + + seq = iaa.OneOf(one_of) + return seq + + + +def get_train_fixedsize_transform(image_size): + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + seq = iaa.Sequential([ + iaa.Fliplr(0.5), # horizontal flips + sometimes(iaa.OneOf([ + # Resize complety the image + iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL), + # Crop into the image + iaa.CropToFixedSize(image_size[1], image_size[0]), + # Affine transform + iaa.Affine( + scale={"x": (0.5, 1.5), "y": (0.5, 1.5)}, + ) + ])), + # Be sure to resize to the target image size + iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL) + ], random_order=False) # apply augmenters in random order + return seq + + +def get_valid_fixedsize_transform(image_size): + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + seq = iaa.Sequential([ + # Be sure to resize to the target image size + iaa.Resize({"width": image_size[1], "height": image_size[0]}) + ], random_order=False) # apply augmenters in random order + return seq + + def bbox_xcyc_wh_to_imgaug_bbox(bbox, target_class, height, width): img_aug_bbox = [] @@ -64,52 +158,23 @@ def detr_aug_seq(image, config, augmenation): max_side_max = 1333 image_size = config.image_size - if augmenation: - - seq = iaa.Sequential([ - iaa.Fliplr(0.5), # horizontal flips - sometimes(iaa.OneOf([ - # Resize complety the image - iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL), - # Crop into the image - iaa.CropToFixedSize(image_size[1], image_size[0]), - # Affine transform - iaa.Affine( - scale={"x": (0.5, 1.5), "y": (0.5, 1.5)}, - ) - ])), - # Be sure to resize to the target image size - iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL) - ], random_order=False) # apply augmenters in random order - - return seq - + # Multi scale training + if image_size is None: + if augmenation: + return get_multiscale_transform(image, max_size=1300) + else: + return get_multiscale_transform( + image, + scales=[800], + random_crop=None, + random_resize=None, + max_size=1300 + ) else: - - seq = iaa.Sequential([ - # Be sure to resize to the target image size - iaa.Resize({"width": image_size[1], "height": image_size[0]}) - ], random_order=False) # apply augmenters in random order - - return seq - - """ Mode paper evaluation - # Evaluation mode, we took the largest min side the model is trained on - target_min_side_size = 480 - image_min_side = min(float(image.shape[0]), float(image.shape[1])) - image_max_side = max(float(image.shape[0]), float(image.shape[1])) - - min_side_scaling = target_min_side_size / image_min_side - max_side_scaling = max_side_max / image_max_side - scaling = min(min_side_scaling, max_side_scaling) - - n_height = int(scaling * image.shape[0]) - n_width = int(scaling * image.shape[1]) - - seq = iaa.Sequential([ - iaa.Resize({"height": n_height, "width": n_width}), - ]) - """ + if augmenation: + return get_train_fixedsize_transform(image_size) + else: + return get_valid_fixedsize_transform(image_size) return seq diff --git a/detr_tf/data/voc.py b/detr_tf/data/voc.py index 34e88c30..8390e672 100644 --- a/detr_tf/data/voc.py +++ b/detr_tf/data/voc.py @@ -19,9 +19,9 @@ 'sheep', 'sofa', 
'train', 'tvmonitor' ] -def load_voc_labels(img_id, class_names, voc_dir, augmentation, config): +def load_voc_labels(img_id, class_names, voc_dir, ann_dir, augmentation, config): - anno_path = os.path.join(voc_dir, config.data.ann_dir, img_id + '.xml') + anno_path = os.path.join(voc_dir, ann_dir, img_id + '.xml') objects = ET.parse(anno_path).findall('object') size = ET.parse(anno_path).find('size') width = float(size.find("width").text) @@ -55,13 +55,13 @@ def load_voc_labels(img_id, class_names, voc_dir, augmentation, config): return t_bbox, t_class -def load_voc_from_id(img_id, class_names, voc_dir, augmentation, config, img_dir): +def load_voc_from_id(img_id, class_names, voc_dir, ann_dir, augmentation, config, img_dir): img_id = str(img_id.decode()) # Load image - img_path = os.path.join(voc_dir, config.data.img_dir, img_id + '.jpg') + img_path = os.path.join(voc_dir, img_dir, img_id + '.jpg') image = imageio.imread(img_path) # Load labels - t_bbox, t_class = load_voc_labels(img_id, class_names, voc_dir, augmentation, config) + t_bbox, t_class = load_voc_labels(img_id, class_names, voc_dir, ann_dir, augmentation, config) # Apply augmentations if augmentation is not None: image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation) @@ -77,7 +77,25 @@ def load_voc_from_id(img_id, class_names, voc_dir, augmentation, config, img_dir def load_voc_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None): - """ + """ Load a VOC dataset + + Parameters + ---------- + config: TrainingConfig + Instance of TrainingConfig + batch_size: int + Size of the desired batch size + augmentation: bool + Apply augmentations on the training data + ann_dir: str + Path to the coco dataset + If None, will be equal to config.data.ann_dir + ann_file: str + Path to the ann_file relative to the ann_dir + If None, will be equal to config.data.ann_file + img_dir: str + Path to the img_dir relative to the data_dir + If None, will be equal to config.data.img_dir """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir ann_file = config.data.ann_file if ann_file is None else ann_file @@ -115,7 +133,7 @@ def load_voc_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_f dataset = dataset.shuffle(1000) # Retrieve img and labels dataset = dataset.map(lambda idx: processing.numpy_fc(idx, load_voc_from_id, - class_names=class_names, voc_dir=config.data.data_dir, augmentation=augmentation, config=config, img_dir=img_dir) + class_names=class_names, voc_dir=config.data.data_dir, ann_dir=ann_dir, augmentation=augmentation, config=config, img_dir=img_dir) , num_parallel_calls=tf.data.experimental.AUTOTUNE) # Filter labels to be sure to keep only sample with at least one bbox dataset = dataset.filter(lambda imgs, tbbox, tclass: tf.shape(tbbox)[0] > 0) diff --git a/detr_tf/inference.py b/detr_tf/inference.py index c82fa1c0..957233fd 100644 --- a/detr_tf/inference.py +++ b/detr_tf/inference.py @@ -65,24 +65,34 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[ return image -def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center"): +def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center", threshold=None): + + #print('get model inference', [key for key in m_outputs]) + + # Detr or deformable + predicted_bbox = m_outputs["pred_boxes"][0] if "pred_boxes" in m_outputs else m_outputs["bbox_pred_boxes"][0] + predicted_labels = m_outputs["pred_logits"][0] if 
"pred_logits" in m_outputs else m_outputs["bbox_pred_logits"][0] + activation = "softmax" if "pred_boxes" in m_outputs else "sigmoid" + + if activation == "softmax": # Detr + softmax = tf.nn.softmax(predicted_labels) + predicted_scores = tf.reduce_max(softmax, axis=-1) + predicted_labels = tf.argmax(softmax, axis=-1) + indices = tf.where(predicted_labels != background_class) + indices = tf.squeeze(indices, axis=-1) + else: # Deformable Detr + sigmoid = tf.nn.sigmoid(predicted_labels) + predicted_scores = tf.reduce_max(sigmoid, axis=-1) + predicted_labels = tf.argmax(sigmoid, axis=-1) + threshold = 0.1 if threshold is None else threshold + indices = tf.where(predicted_scores > threshold) + indices = tf.squeeze(indices, axis=-1) - predicted_bbox = m_outputs["pred_boxes"][0] - predicted_labels = m_outputs["pred_logits"][0] - - softmax = tf.nn.softmax(predicted_labels) - predicted_scores = tf.reduce_max(softmax, axis=-1) - predicted_labels = tf.argmax(softmax, axis=-1) - - - indices = tf.where(predicted_labels != background_class) - indices = tf.squeeze(indices, axis=-1) predicted_scores = tf.gather(predicted_scores, indices) predicted_labels = tf.gather(predicted_labels, indices) predicted_bbox = tf.gather(predicted_bbox, indices) - if bbox_format == "xy_center": predicted_bbox = predicted_bbox elif bbox_format == "xyxy": diff --git a/detr_tf/logger/training_logging.py b/detr_tf/logger/training_logging.py index 2291033e..c3560e29 100644 --- a/detr_tf/logger/training_logging.py +++ b/detr_tf/logger/training_logging.py @@ -21,13 +21,13 @@ RAGGED = False -def tf_send_batch_log_to_wandb(images, target_bbox, target_class, m_outputs: dict, config, class_name=[], step=None, prefix=""): +def tf_send_batch_log_to_wandb(images, target_bbox, target_class, m_outputs: dict, config, batch_size, class_name=[], step=None, prefix=""): # Warning: In graph mode, this class is init only once. In eager mode, this class is init at each step. 
img_sender = WandbSender() predicted_bbox = m_outputs["pred_boxes"] - for b in range(predicted_bbox.shape[0]): + for b in range(batch_size): # Select within the batch the elements at indice b image = images[b] @@ -67,17 +67,15 @@ def compute_map_on_batch(images, target_bbox, target_class, m_outputs: dict, co # Target t_bbox, t_class = target_bbox[b], target_class[b] - if not RAGGED: - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) - t_bbox = bbox.xcycwh_to_yx_min_yx_max(t_bbox) - t_class = tf.slice(t_class, [1, 0], [size, -1]) - t_class = tf.squeeze(t_class, axis=-1) + #t_class = tf.slice(t_class, [1, 0], [size, -1]) + t_bbox = bbox.xcycwh_to_yx_min_yx_max(t_bbox) + t_class = tf.squeeze(t_class, axis=-1) # Inference ops predicted_bbox, predicted_labels, predicted_scores = get_model_inference(elem_m_outputs, config.background_class, bbox_format="yxyx") pred_mask = None - + + # Fake masks (durty adapted code) pred_mask = np.zeros((138, 138, len(predicted_bbox))) target_mask = np.zeros((138, 138, len(t_bbox))) WandbSender.compute_map( @@ -89,18 +87,30 @@ def compute_map_on_batch(images, target_bbox, target_class, m_outputs: dict, co -def train_log(images, t_bbox, t_class, m_outputs: dict, config, step, class_name=[], prefix="train/"): - # Every 1000 steps, log some progress of the training +def train_log(data, m_outputs: dict, config, step, class_name=[], prefix="train/"): + # Every x steps, log some progress of the training # (Images with bbox and images logs) if step % 100 == 0: - tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=step, prefix=prefix) + tf_send_batch_log_to_wandb(data["images"], data["target_bbox"], data["target_class"], m_outputs, config, config.batch_size, class_name=class_name, step=step, prefix=prefix) -def valid_log(images, t_bbox, t_class, m_outputs: dict, config, step, global_step, class_name=[], evaluation_step=200, prefix="train/"): +def valid_log(data: dict, m_outputs: dict, config, batch_size, step, global_step, class_name=[], evaluation_step=200, prefix="train/"): # Set the number of class WandbSender.init_ap_data(nb_class=len(class_name)) - map_list = compute_map_on_batch(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, send=(step+1==evaluation_step), prefix="val/") + + # Compute AP + map_list = compute_map_on_batch( + images=data["images"], + target_bbox=data["target_bbox"], + target_class=data["target_class"], + m_outputs=m_outputs, + config=config, + class_name=class_name, + step=global_step, + send=(step+1==evaluation_step), + prefix="val/" + ) if step == 0: - tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, prefix="val/") + tf_send_batch_log_to_wandb(data["images"], data["target_bbox"], data["target_class"], m_outputs, config, batch_size, class_name=class_name, step=global_step, prefix="val/") diff --git a/detr_tf/logger/wandb_logging.py b/detr_tf/logger/wandb_logging.py index b3b0d3a9..b2f10b30 100644 --- a/detr_tf/logger/wandb_logging.py +++ b/detr_tf/logger/wandb_logging.py @@ -118,7 +118,7 @@ def compute_map(p_bbox: np.array, p_labels: np.array, p_scores: np.array, t_bbox except Exception as e: print("compute_map error. 
e=", e) - #raise e + raise e return np.array([0.0, 0.0], np.float64) return np.array([0.0, 0.0], np.float64) diff --git a/detr_tf/loss/compute_map.py b/detr_tf/loss/compute_map.py index fd7412c5..6b5027ec 100644 --- a/detr_tf/loss/compute_map.py +++ b/detr_tf/loss/compute_map.py @@ -182,13 +182,6 @@ def print_maps(all_maps): def cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, gt_classes, t_mask, ap_data, iou_thresholds): - #print("p_bbox", p_bbox.shape) - #print("p_labels", p_labels.shape) - #print("p_scores", p_scores.shape) - #print("p_mask", p_mask.shape) - #print("t_bbox", t_bbox.shape) - #print("gt_classes", gt_classes) - #print("t_mask", t_mask.shape) num_crowd = 0 @@ -220,8 +213,7 @@ def cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, gt_classes, t_mask, ap_d lambda i,j: crowd_mask_iou_cache[i,j].item(), lambda i: mask_scores[i], mask_indices) ] - #print("run", list(classes), list(gt_classes)) - #print(classes + gt_classes) + for _class in set(list(classes) + list(gt_classes)): ap_per_iou = [] num_gt_for_class = sum([1 for x in gt_classes if x == _class]) diff --git a/detr_tf/loss/hungarian_matching.py b/detr_tf/loss/hungarian_matching.py index 35658307..301b4fc2 100644 --- a/detr_tf/loss/hungarian_matching.py +++ b/detr_tf/loss/hungarian_matching.py @@ -162,11 +162,13 @@ def loss_boxes(outputs, targets, indices, num_boxes): def hungarian_matching(t_bbox, t_class, p_bbox, p_class, fcost_class=1, fcost_bbox=5, fcost_giou=2, slice_preds=True) -> tuple: - if slice_preds: - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) - t_class = tf.slice(t_class, [1, 0], [size, -1]) - t_class = tf.squeeze(t_class, axis=-1) + t_class = tf.squeeze(t_class, axis=-1) + _filter = tf.squeeze(tf.where(t_class != -1), axis=-1) + #print("t_class", t_class.shape) + t_class = tf.gather(t_class, _filter) + #print('t_class', t_class.shape) + t_bbox = tf.gather(t_bbox, _filter) + #print('t_bbox', t_bbox.shape) # Convert frpm [xc, yc, w, h] to [xmin, ymin, xmax, ymax] p_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(p_bbox) diff --git a/detr_tf/loss/loss.py b/detr_tf/loss/loss.py index 2d87e488..1cd4eee9 100644 --- a/detr_tf/loss/loss.py +++ b/detr_tf/loss/loss.py @@ -19,13 +19,13 @@ def get_total_losss(losses): return total_loss -def get_losses(m_outputs, t_bbox, t_class, config): - losses = get_detr_losses(m_outputs, t_bbox, t_class, config) +def get_losses(m_outputs, t_bbox, t_class, config, batch_size): + losses = get_detr_losses(m_outputs, t_bbox, t_class, config, batch_size) # Get auxiliary loss for each auxiliary output if "aux" in m_outputs: for a, aux_m_outputs in enumerate(m_outputs["aux"]): - aux_losses = get_detr_losses(aux_m_outputs, t_bbox, t_class, config, suffix="_{}".format(a)) + aux_losses = get_detr_losses(aux_m_outputs, t_bbox, t_class, config, batch_size, suffix="_{}".format(a)) losses.update(aux_losses) # Compute the total loss @@ -95,7 +95,7 @@ def loss_boxes(p_bbox, p_class, t_bbox, t_class, t_indices, p_indices, t_selecto return loss_giou, l1_loss -def get_detr_losses(m_outputs, target_bbox, target_label, config, suffix=""): +def get_detr_losses(m_outputs, target_bbox, target_label, config, batch_size, suffix=""): predicted_bbox = m_outputs["pred_boxes"] predicted_label = m_outputs["pred_logits"] @@ -112,9 +112,10 @@ def get_detr_losses(m_outputs, target_bbox, target_label, config, suffix=""): t_offset = 0 p_offset = 0 - for b in range(predicted_bbox.shape[0]): + for b in range(batch_size): p_bbox, p_class, t_bbox, t_class = predicted_bbox[b], 
predicted_label[b], target_bbox[b], target_label[b] + t_indices, p_indices, t_selector, p_selector, t_bbox, t_class = hungarian_matching(t_bbox, t_class, p_bbox, p_class, slice_preds=True) t_indices = t_indices + tf.cast(t_offset, tf.int64) diff --git a/detr_tf/networks/custom_layers.py b/detr_tf/networks/custom_layers.py index 1a0c524e..bfe0b967 100644 --- a/detr_tf/networks/custom_layers.py +++ b/detr_tf/networks/custom_layers.py @@ -28,23 +28,27 @@ def compute_output_shape(self, input_shape): return input_shape + class Linear(tf.keras.layers.Layer): ''' Use this custom layer instead of tf.keras.layers.Dense to allow loading converted PyTorch Dense weights that have shape (output_dim, input_dim) ''' - def __init__(self, output_dim, **kwargs): + def __init__(self, output_dim, kernel_initializer=tf.keras.initializers.GlorotUniform(), bias_initializer=tf.keras.initializers.GlorotUniform(), **kwargs): super().__init__(**kwargs) self.output_dim = output_dim + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + def build(self, input_shape): self.kernel = self.add_weight(name='kernel', shape=[self.output_dim, input_shape[-1]], - initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + initializer=self.kernel_initializer, trainable=True) self.bias = self.add_weight(name='bias', shape=[self.output_dim], - initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + initializer=self.bias_initializer, trainable=True) def call(self, x): return tf.matmul(x, self.kernel, transpose_b=True) + self.bias @@ -65,3 +69,37 @@ def build(self, input_shape): def call(self, x=None): return self.w + + +class ScaleLevelEmbedding(tf.keras.layers.Layer): + def __init__(self, num_level, embed_shape, **kwargs): + super().__init__(**kwargs) + self.embed_shape = embed_shape + self.num_level = num_level + + def build(self, input_shape): + self.w = self.add_weight(name='kernel', shape=(self.num_level, self.embed_shape), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0), trainable=True) + + super().build(input_shape) + + def call(self, x=None): + return self.w + + +class MLP(tf.keras.layers.Layer): + def __init__(self, hidden_dim, output_dim, kernel_initializer=tf.keras.initializers.GlorotUniform(), bias_initializer=tf.keras.initializers.GlorotUniform(), **kwargs): + super().__init__(**kwargs) + + self.layer_0 = Linear(hidden_dim, name='layer_0') + self.layer_1 = Linear(hidden_dim, name='layer_1') + self.layer_2 = Linear(output_dim, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, name='layer_2') + + + def call(self, x, training=False): + x = tf.nn.relu(self.layer_0(x)) + x = tf.nn.relu(self.layer_1(x)) + x = self.layer_2(x) + + return x + diff --git a/detr_tf/networks/deformable_detr.py b/detr_tf/networks/deformable_detr.py new file mode 100644 index 00000000..9bc51c00 --- /dev/null +++ b/detr_tf/networks/deformable_detr.py @@ -0,0 +1,342 @@ +import pickle +import tensorflow as tf +import numpy as np +import time +import cv2 +import matplotlib.pyplot as plt +import os +import math +import json +from pathlib import Path +import tensorflow_addons as tfa +import functools +import collections + +from detr_tf.networks.deformable_transformer import DeformableTransformer +from detr_tf.networks.transformer import MultiHeadAttention +from detr_tf.networks.resnet_backbone import ResNet50Backbone +from detr_tf.networks.custom_layers import Linear, FixedEmbedding, ScaleLevelEmbedding, MLP +from detr_tf.networks.position_embeddings 
import PositionEmbeddingSine +from detr_tf.networks.transformer import Transformer +from detr_tf.networks.weights import load_weights + +class DeformableDETR(tf.keras.Model): + def __init__(self, + model_dim=256, + num_classes=91, + num_queries=300, + num_sampling_points=4, + backbone=None, + pos_encoder=None, + transformer=None, + num_encoder_layers=6, + num_decoder_layers=6, + return_intermediate_dec=True, + init_query_embedding=False, + batch_size=None, + use_mask_bn=False, + refine_bbox=True, + multiscale=True, + train_encoder=False, + **kwargs): + super().__init__(**kwargs) + self.num_queries = num_queries + + self.backbone = ResNet50Backbone(name='backbone') + + self.pos_encoder = pos_encoder or PositionEmbeddingSine( + num_pos_features=model_dim // 2, normalize=True, center=True) + + self.query_embed = FixedEmbedding((num_queries, model_dim*2), name='query_embed') + self.level_embed = ScaleLevelEmbedding(4, model_dim, name="level_embed", trainable=train_encoder) + + + self.multiscale = multiscale + + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed = list( + Linear(num_classes, + bias_initializer=tf.keras.initializers.Constant(bias_value), + name=f'class_embed_{i}') + for i in range(num_decoder_layers)) + + self.bbox_embed = list( + MLP(model_dim, + 4, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name=f'bbox_embed_{i}') + for i in range(num_decoder_layers)) + + # hack to force shared weight (different from pytorch cloning approach) + if not refine_bbox: + self.class_embed = [self.class_embed[0] for _ in range(num_decoder_layers)] + self.bbox_embed = [self.bbox_embed[0] for _ in range(num_decoder_layers)] + + self.transformer = transformer or DeformableTransformer( + query_embed_layer=self.query_embed, + level_embed=self.level_embed, + layer_position_embedding_sine=self.pos_encoder, + model_dim=model_dim, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + num_sampling_points=num_sampling_points, + return_intermediate_dec=return_intermediate_dec, + init_query_embedding=init_query_embedding, + class_embed=self.class_embed, + bbox_embed=self.bbox_embed, + refine_bbox=refine_bbox, + train_encoder=train_encoder, + name='transformer' + ) + self.model_dim = model_dim + + layer_norm = functools.partial(tfa.layers.GroupNormalization, groups=32, epsilon=1e-05) #tf.keras.layers.BatchNormalization + + self.input_proj_0 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/0/0', trainable=train_encoder) + self.input_proj_gn_0 = layer_norm(name="input_proj_gn/0/1", trainable=train_encoder) + + if multiscale: + self.input_proj_1 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/1/0', trainable=train_encoder) + self.input_proj_2 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/2/0', trainable=train_encoder) + self.input_proj_3 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=3, strides=2, name='input_proj/3/0', trainable=train_encoder) + + self.input_proj_gn_1 = layer_norm(name="input_proj_gn/1/1", trainable=train_encoder) + self.input_proj_gn_2 = layer_norm(name="input_proj_gn/2/1", trainable=train_encoder) + self.input_proj_gn_3 = layer_norm(name="input_proj_gn/3/1", trainable=train_encoder) + + #self.activation = tf.keras.layers.ReLU() + + self.num_decoder_layers = num_decoder_layers + + + def call(self, inp, training=False, post_process=False): + x = inp + backbone_outputs = self.backbone(x, 
training=training) + x2, x1, x0, _ = backbone_outputs + + if self.multiscale: + src_proj_outputs = [self.input_proj_gn_0(self.input_proj_0(x0)), \ + self.input_proj_gn_1(self.input_proj_1(x1)), \ + self.input_proj_gn_2(self.input_proj_2(x2)), \ + self.input_proj_gn_3(tf.keras.layers.ZeroPadding2D(1)(self.input_proj_3(x2)))] + else: + src_proj_outputs = [self.input_proj_gn_0(self.input_proj_0(x2))] + + masks = list(tf.zeros([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1], tf.float32) for x in src_proj_outputs) + + decoder, encoder, outputs_coord = self.transformer(src_proj_outputs, + masks, + training=training) + + outputs_class = list(class_embed(x) for class_embed, x in zip(self.class_embed, tf.split(decoder, self.num_decoder_layers)) ) + + output = {'bbox_pred_logits': outputs_class[-1], + 'bbox_pred_boxes': outputs_coord[-1]} + + if post_process: + output = self.post_process(output) + + a = self.query_embed(None) + + return output + + +class DetrClassHead(tf.keras.layers.Layer): + + def __init__(self, detr, include_top, nb_class=None, refine_bbox=False, **kwargs): + """ + """ + super().__init__(name="detr_class_head", **kwargs) + self.include_top = include_top + if self.include_top: + if refine_bbox: + self.layer_class_embed = list(detr.get_layer(f'class_embed_{i}') for i in range(detr.num_decoder_layers)) + else: + #shared weights + self.layer_class_embed = list(detr.get_layer(f'class_embed_0') for _ in range(detr.num_decoder_layers) ) + else: + # Setup the new layers + if refine_bbox: + self.layer_class_embed = list(tf.keras.layers.Dense(nb_class, name=f"class_embed_{i}") for i in range(detr.num_decoder_layers) ) + else: + layer = tf.keras.layers.Dense(nb_class, name=f"class_embed_0") + self.layer_class_embed = list(layer for i in range(detr.num_decoder_layers) ) + + def call(self, decoder_state): + outputs = {} + + # Output class + outputs_class = [l(s) for l, s in zip(self.layer_class_embed, tf.unstack(decoder_state, axis=0))] + + outputs = {'bbox_pred_logits': outputs_class[-1]} + + outputs["bbox_aux"] = [] + for out_class in outputs_class: + outputs["bbox_aux"].append({ + "bbox_pred_logits": out_class + }) + + return outputs + + + def build(self, input_shape=None, **kwargs): + super().build(input_shape, **kwargs) + + + +def get_detr_core(detr, backbone, model_dim, tf_backbone=False, multiscale=True): + """ DETR Core is made of the backbone and the transformer part without the + heads + """ + + layer_transformer = detr.get_layer("transformer") + + #### Set ops + if not tf_backbone: + image_input = tf.keras.Input((None, None, 3)) + backbone_outputs = backbone(image_input) + x2, x1, x0, _ = backbone_outputs + else: + image_input = backbone.inputs + _ = backbone.get_layer("conv1_relu").output #/2 + _ = backbone.get_layer("conv2_block3_out").output #/4 + x0 = backbone.get_layer("conv3_block4_out").output #/8 + x1 = backbone.get_layer("conv4_block6_out").output #/16 + x2 = backbone.get_layer("conv5_block3_out").output #/32 + backbone_outputs = x2, x1, x0, _ + + if multiscale: + src_proj_outputs = list((None,None, None, None)) + for i, tensor in enumerate([x0, x1, x2, x2]): + input_proj_layer = detr.get_layer(f'input_proj/{i}/0') + input_proj_gn_layer = detr.get_layer(f'input_proj_gn/{i}/1') + if i == 3: + tensor = tf.keras.layers.ZeroPadding2D(1)(tensor) + tensor = input_proj_layer(tensor) + + src_proj_outputs[i] = input_proj_gn_layer(tensor) + else: + input_proj_layer = detr.get_layer(f'input_proj/0/0') + input_proj_gn_layer = detr.get_layer(f'input_proj_gn/0/1') + src_proj_outputs 
= [input_proj_gn_layer(input_proj_layer(x2))] + + masks = list(tf.zeros([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1], tf.float32) for x in src_proj_outputs) + + decoder, encoder, outputs_coord = layer_transformer(src_proj_outputs, masks) + + detr = tf.keras.Model(image_input, [outputs_coord, decoder, encoder, src_proj_outputs, backbone_outputs], name="detr_core") + + return detr + + +def get_deformable_detr_model( + config, + + include_top=False, + include_bbox=True, + nb_class=None, + weights=None, + tf_backbone=False, + + batch_size=None, + num_decoder_layers=6, + num_encoder_layers=6, + + use_mask_bn=False, + + + refine_bbox=False, + return_intermediate_dec=True, + model_dim=256, + multiscale=True, + include_bbox_3d=False, + bbox_3d_config=None, + train_encoder=True, + + ): + if weights == "deformable-detr-refine_bbox" and nb_class is not None and include_top and nb_class != 91: + raise ValueError('"deformable_detr" weights are trained with 92 outputs. Do not include the network top to set this number of class') + elif weights == "deformable_detr" and nb_class is None: + nb_class = 91 + + if weights != "deformable-detr-refine_bbox" and refine_bbox and weights is not None: + raise ValueError('"Trying to instanciate deformable_detr_bbox_refined with deformable_detr weights') + + init_query_embedding = False #if weights == "deformable_detr" else True + # Load model and weights + detr = DeformableDETR(num_classes=nb_class, + num_decoder_layers=num_decoder_layers, + num_encoder_layers=num_encoder_layers, + batch_size=batch_size, + init_query_embedding=init_query_embedding, + use_mask_bn=use_mask_bn, + + refine_bbox=refine_bbox, + return_intermediate_dec=return_intermediate_dec, + + multiscale=multiscale, + train_encoder=train_encoder) + + image_shape = (None, None, 3) + + # Backbone + if not tf_backbone: + backbone = detr.get_layer("backbone") + else: + config.normalized_method = "tf_resnet" + backbone = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_shape=(None, None, 3)) + + if weights is not None: + load_weights(detr, weights) + + # Backbone + if not tf_backbone: + backbone = detr.get_layer("backbone") + else: + backbone = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_shape=image_shape) + + # Get detr core: backbone + transformer + image_input = tf.keras.Input(image_shape, batch_size=batch_size) + + detr_core_outputs = get_detr_core(detr, backbone, model_dim, tf_backbone=tf_backbone, multiscale=multiscale)(image_input) + + if include_top is False and nb_class is None: + return tf.keras.Model(image_input, detr_core_outputs, name="detr_core") + + + outputs_coord, decoder_state, encoder_state, src_proj_output, backbone_outs = detr_core_outputs + + outputs = {"backbone_outs":list(backbone_outs), "src_proj_output":list(src_proj_output), "encoder_state":encoder_state} + + if include_bbox: + + outputs['bbox_pred_boxes'] = outputs_coord[-1] + outputs["bbox_aux"] = [] + for i in range(0, outputs_coord.shape[0] - 1): + outputs["bbox_aux"].append({ + "bbox_pred_boxes": outputs_coord[i] + }) + + # Add bbox head + class_head = DetrClassHead(detr, include_top=include_top, nb_class=nb_class, refine_bbox=refine_bbox) + bbox_outputs = class_head(decoder_state) + config.add_heads([class_head]) + update(outputs, bbox_outputs) + + deformable_detr = tf.keras.Model(image_input, outputs, name="deformable_detr") + return deformable_detr + +def update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = 
update(d.get(k, {}), v) + else: + d[k] = v + return d + + +if __name__ == "__main__": + main() diff --git a/detr_tf/networks/deformable_transformer.py b/detr_tf/networks/deformable_transformer.py new file mode 100644 index 00000000..4086c4c7 --- /dev/null +++ b/detr_tf/networks/deformable_transformer.py @@ -0,0 +1,583 @@ +import tensorflow as tf +from tensorflow.keras.layers import Dropout, Activation, LayerNormalization +import math +from .custom_layers import Linear +from .transformer import MultiHeadAttention + + +USE_CUDA_MS_DEFORM_IM2COL = True + +if USE_CUDA_MS_DEFORM_IM2COL: + from detr_tf.custom_ops.ms_deform_attn import ms_deform_im2col +else: + from detr_tf.custom_ops.ms_deform_attn.ms_deform_attn import MSDeformAttnFunction + +class DeformableTransformer(tf.keras.layers.Layer): + def __init__(self, + layer_position_embedding_sine, + level_embed, + class_embed, + bbox_embed, + query_embed_layer=None, + model_dim=256, + num_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + num_sampling_points=4, + dim_feedforward=1024, + dropout=0.1, + activation='relu', + return_intermediate_dec=False, + init_query_embedding=False, + use_track_query=False, + refine_bbox=False, + multiscale=True, + train_encoder=True, + **kwargs): + + super().__init__(**kwargs) + + self.model_dim = model_dim + self.num_heads = num_heads + + self.layer_position_embedding_sine = layer_position_embedding_sine + self.query_embed_layer = query_embed_layer + + self.level_embed = level_embed + + self.class_embed = class_embed + self.bbox_embed = bbox_embed + + self.init_query_embedding = init_query_embedding + + self.multiscale = multiscale + + self.encoder = DeformableEncoder(model_dim, num_heads, dim_feedforward, + dropout, activation, + num_encoder_layers, num_sampling_points=num_sampling_points, name='encoder', trainable=train_encoder) + + self.decoder = DeformableDecoder(class_embed, bbox_embed, model_dim, num_heads, dim_feedforward, + dropout, activation, + num_decoder_layers, + name='decoder', + num_sampling_points=num_sampling_points, refine_bbox=refine_bbox, + return_intermediate=return_intermediate_dec, use_track_query=use_track_query) + + + if self.init_query_embedding: + raise NotImplementedError() + self.query_encoding = self.add_weight(name='query_embedding', shape=(100, 256), + initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + else: + self.init_query_embedding = init_query_embedding + + self.reference_points = Linear(2, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name="reference_points") + + def get_reference_points(self, spatial_shapes): + reference_points_list = [] + for lvl, (H_W_) in enumerate(spatial_shapes): + H_, W_ = tf.unstack(H_W_) + ref_y, ref_x = tf.meshgrid(tf.linspace(0.5, tf.cast(H_, tf.float32) - 0.5, H_), + tf.linspace(0.5, tf.cast(W_, tf.float32) - 0.5, W_), indexing='ij') + + ref_y = ref_y / tf.cast(H_, tf.float32) + ref_x = ref_x / tf.cast(W_, tf.float32) + + ref = tf.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + # L, (H, W, 2) + return reference_points_list + + def call(self, source, mask, track_query=None, track_query_mask=None, track_query_reference_points=None, training=False): + + N = tf.shape(source[0])[0] + + + pos_encoding = list(self.layer_position_embedding_sine(m) for m in mask) + + if self.init_query_embedding: + # Panoptic embedding + query_encoding = self.query_encoding + else: + # Detection embedding + query_encoding = self.query_embed_layer(None) + + query_encoding, 
target = tf.split(query_encoding, 2, axis=1) + query_encoding = tf.expand_dims(query_encoding, axis=1) + + if self.level_embed is not None: + level_embed = self.level_embed(None) + lvl_pos_embed = list(level_embed[lvl, None, :] + tf.reshape(p, (N, -1, self.model_dim) ) for lvl, p in enumerate(pos_encoding) ) # N, (H*W), C + 1, 1, C + lvl_pos_embed_flatten = tf.concat(lvl_pos_embed, axis=1) # N, sum_L(H*W), C + else: + lvl_pos_embed_flatten = None + + # L, 2 + input_spatial_shapes = list(tf.shape(src)[1:3] for src in source) + encoder_reference_points = self.get_reference_points(input_spatial_shapes) + input_spatial_shapes = tf.stack(input_spatial_shapes, 0) + + # L, + input_level_start_index = tf.math.reduce_prod(input_spatial_shapes, axis=1) + input_level_start_index = tf.math.cumsum(input_level_start_index, axis=0, exclusive=True) + + # L, (H, W, 2) -> L, (1, H*W, 2) + encoder_reference_points = list( tf.reshape(rp_l, (1, -1, 2)) for rp_l in encoder_reference_points) + encoder_reference_points = tf.concat(encoder_reference_points, axis=1) + + + + #Flatten sources + source = [tf.reshape(s, (N, -1, self.model_dim) ) for s in source] + source = tf.concat(source, axis=1) + memory = self.encoder(source, encoder_reference_points, source_key_padding_mask=mask, + pos_encoding=lvl_pos_embed_flatten, + training=training, + source_spatial_shapes=input_spatial_shapes, + source_level_start_index=input_level_start_index) + + + decoder_reference_points = tf.math.sigmoid(self.reference_points(query_encoding)) + decoder_reference_points = tf.tile(decoder_reference_points, [1, N, 1]) + + if track_query_reference_points is not None: + decoder_reference_points = tf.concat([track_query_reference_points, decoder_reference_points], axis=0) + + target = tf.reshape(target, (300, 1, self.model_dim) ) + target = tf.tile(target, [1, N, 1]) + + + hs, reference_points = self.decoder(target, memory, decoder_reference_points, memory_key_padding_mask=mask, + pos_encoding=lvl_pos_embed_flatten, query_encoding=query_encoding, + track_query=track_query, track_query_mask=track_query_mask, + memory_spatial_shapes=input_spatial_shapes, + memory_level_start_index=input_level_start_index, + training=training) + + return tf.transpose(hs, [0, 2, 1, 3]), tf.transpose(memory, (1, 0, 2)), tf.transpose(reference_points, [0, 2, 1, 3]) + + +class DeformableEncoder(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048, + dropout=0.1, activation='relu', + num_encoder_layers=6, num_sampling_points=4, **kwargs): + super().__init__(**kwargs) + + self.enc_layers = [DeformableEncoderLayer(model_dim, num_heads, num_sampling_points, dim_feedforward, + dropout, activation, + name='layer_%d'%i) + for i in range(num_encoder_layers)] + + + def call(self, source, reference_points, mask=None, source_key_padding_mask=None, + pos_encoding=None, track_query=None, + source_spatial_shapes=None, source_level_start_index=None, training=False): + x = source + + for l_id, layer in enumerate(self.enc_layers): + x = layer(x, reference_points, source_mask=mask, source_key_padding_mask=source_key_padding_mask, + pos_encoding=pos_encoding, input_spatial_shapes=source_spatial_shapes, input_level_start_index=source_level_start_index, training=training) + + return x + + +class DeformableDecoder(tf.keras.layers.Layer): + def __init__(self, class_embed, bbox_embed, model_dim=256, num_heads=8, dim_feedforward=2048, + dropout=0.1, activation='relu', + num_decoder_layers=6, num_sampling_points=4, return_intermediate=False, 
use_track_query=False, refine_bbox=False, **kwargs): + super().__init__(**kwargs) + + self.dec_layers = [DeformableDecoderLayer(model_dim, num_heads, num_sampling_points, dim_feedforward, + dropout, activation, use_track_query=use_track_query, + name='layer_%d'%i) + for i in range(num_decoder_layers)] + + self.class_embed = class_embed + self.bbox_embed = bbox_embed + + self.refine_bbox = refine_bbox + self.return_intermediate = return_intermediate + + + def call(self, target, memory, reference_points, target_mask=None, memory_mask=None, + target_key_padding_mask=None, memory_key_padding_mask=None, memory_spatial_shapes=None, memory_level_start_index=None, + pos_encoding=None, query_encoding=None, track_query=None, track_query_mask=None, training=False): + + + x = target + intermediate = [] + intermediate_reference_points = [] + + new_reference_points = reference_points + + for l_id, layer in enumerate(self.dec_layers): + # if the tracking is not use + # track_query we'll simply stay None. + x, track_query = layer(x, memory, reference_points, + target_mask=target_mask, + memory_mask=memory_mask, + target_key_padding_mask=target_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos_encoding=pos_encoding, + track_query=track_query, + track_query_mask=track_query_mask, + query_encoding=query_encoding, + memory_spatial_shapes=memory_spatial_shapes, + memory_level_start_index=memory_level_start_index) + + if track_query is not None: + out = tf.concat([track_query, x], axis=0) + else: + out = x + + tmp = self.bbox_embed[l_id](out) + if self.refine_bbox: + new_reference_points = inverse_sigmoid(new_reference_points) + else: + new_reference_points = inverse_sigmoid(reference_points) + + if new_reference_points.shape[-1] == 4: + new_reference_points = tmp + new_reference_points + elif new_reference_points.shape[-1] == 2: + xy = tmp[..., :2] + new_reference_points + hw = tmp[..., 2:] + new_reference_points = tf.concat([xy, hw], axis=-1) + else: + raise ValueError() + + + new_reference_points = tf.math.sigmoid(new_reference_points) + + if self.refine_bbox: + reference_points = tf.stop_gradient(new_reference_points) + + if self.return_intermediate: + intermediate.append(out) + intermediate_reference_points.append(new_reference_points) + + + if self.return_intermediate: + return tf.stack(intermediate, axis=0), tf.stack(intermediate_reference_points, axis=0) + + return out, reference_points + + +class DeformableEncoderLayer(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, num_sampling_points=4, dim_feedforward=2048, + dropout=0.1, activation='relu', + **kwargs): + super().__init__(**kwargs) + + self.self_attn = MSDeformableAttention(model_dim, num_heads, num_sampling_points, dropout=dropout, + name='self_attn') + + self.dropout = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + self.activation = Activation(activation) + + self.model_dim = model_dim + + self.linear1 = Linear(dim_feedforward, name='linear1') + self.linear2 = Linear(model_dim, name='linear2') + + self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1') + self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2') + + + def call(self, source, reference_points, source_mask=None, source_key_padding_mask=None, + pos_encoding=None, input_spatial_shapes=None, input_level_start_index=None, training=False): + """ + :param source (N, H, W, C) + :param pos_encoding (1, sum_L(H*W), C) + :param reference_points (H, W, 2) + + :return output (N, sum_L(H*W), C) + 
""" + + N = tf.shape(source[0])[0] + C = self.model_dim + + if pos_encoding is None: + query = source + else: + query = source + pos_encoding + + # # Multi-scale level embedding L, (N, H, W, C) + # query = list(q + level_embed[lvl, None, None, :] for lvl, q in enumerate(query)) + + # # Flatten ¤¤¤¤ + # # L, (N, H*W, C) + # query = list(tf.reshape(q, (N, -1, C) ) for q in query) + # # (N, sum_L{H_*W_}, C) + # query = tf.concat(query, axis=1) + + # (N, Length_{query}, C) + + #print("query", query.shape) + attn_source = self.self_attn(query, reference_points, source, input_spatial_shapes, input_level_start_index) + #print("attn_source", attn_source.shape) + # src = list(tf.reshape(s, (N, -1, C) ) for s in source) + # # (N, sum_L{H_*W_}, C) + # src = tf.concat(src, axis=1) + + source += self.dropout(attn_source, training=training) + source = self.norm1(source) + + #forward_ffn + x = self.linear1(source) + x = self.activation(x) + x = self.dropout2(x, training=training) + x = self.linear2(x) + source += self.dropout3(x, training=training) + source = self.norm2(source) + + # #Unflatten ¤¤¤¤ + # split_size = list(iss[0]*iss[1] for iss in input_spatial_shapes) + # # L, (N, H*W, 2) + # src = tf.split(src, split_size, axis=1) + # # L, (N, H, W, 2) + # src = list(tf.reshape(el, (N, iss[0], iss[1], C) ) for iss, el in zip(input_spatial_shapes, src)) + + return source + + + +class DeformableDecoderLayer(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, num_sampling_points=4, dim_feedforward=2048, + dropout=0.1, activation='relu', use_track_query=False, + **kwargs): + super().__init__(**kwargs) + + self.self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout, + name='self_attn') + self.cross_attn = MSDeformableAttention(model_dim, num_heads, dropout=dropout, + name='cross_attn') + + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + self.dropout4 = Dropout(dropout) + + self.activation = Activation(activation) + + self.linear1 = Linear(dim_feedforward, name='linear1') + self.linear2 = Linear(model_dim, name='linear2') + + self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1') + self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2') + self.norm3 = LayerNormalization(epsilon=1e-5, name='norm3') + + self.use_track_query = use_track_query + + # if self.use_track_query: + # self.dropout = Dropout(dropout) + # self.track_query_norm = LayerNormalization(epsilon=1e-5, name='track_query_norm') + # self.track_query_self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout, name='track_query_self_attn') + + + def call(self, target, memory, reference_points, target_mask=None, memory_mask=None, + target_key_padding_mask=None, memory_key_padding_mask=None, memory_spatial_shapes=None, memory_level_start_index=None, + pos_encoding=None, track_query=None, track_query_mask=None, query_encoding=None, level_embed=None, training=False): + + + if track_query is not None: + # track_query_query = track_query + # track_query_key = track_query + # track_query_target = track_query + + if target_key_padding_mask is None: + target_shape = tf.shape(target) + target_key_padding_mask = tf.zeros((target_shape[1], target_shape[0])) + + # track_query_attn_target = self.track_query_self_attn((track_query_query, track_query_key, track_query_target), key_padding_mask=track_query_mask, need_weights=False) + + # track_query_target += self.dropout(track_query_attn_target, training=training) + # track_query_target = 
self.track_query_norm(track_query_target) + nb_trackquery = tf.shape(track_query)[0] + + # Pad on the left the original query + query_encoding = tf.pad(query_encoding, [[nb_trackquery, 0], [0, 0], [0, 0]], "CONSTANT" ) + # Concat with the track query on the left + target = tf.concat([track_query, target], axis=0) + target_key_padding_mask = tf.concat([track_query_mask, target_key_padding_mask], axis=1) + + # If we use the track query, the query encoding is now padded with zeros for the track queries + # query_tgt = target + query_encoding + query_tgt = key_tgt = target + query_encoding + + attn_target = self.self_attn((query_tgt, key_tgt, target), attn_mask=target_mask, + key_padding_mask=target_key_padding_mask, + need_weights=False) + + target += self.dropout2(attn_target, training=training) + target = self.norm2(target) + + query_tgt = target + query_encoding + + query_tgt = tf.transpose(query_tgt, (1, 0, 2) ) + reference_points = tf.transpose(reference_points, (1, 0, 2) ) + + attn_target2 = self.cross_attn(query_tgt, reference_points, memory, + input_spatial_shapes=memory_spatial_shapes, input_level_start_index=memory_level_start_index) + attn_target2 = tf.transpose(attn_target2, (1, 0, 2) ) + + target += self.dropout1(attn_target2, training=training) + target = self.norm1(target) + + x = self.linear1(target) + x = self.activation(x) + x = self.dropout3(x, training=training) + x = self.linear2(x) + target += self.dropout4(x, training=training) + target = self.norm3(target) + + if track_query is not None: + n_track_query = target[:nb_trackquery] + target = target[nb_trackquery:] + return target, n_track_query + else: + return target, None + + +class SamplingOffsetBiasInitializer(tf.keras.initializers.Initializer): + + def __init__(self, num_heads, num_level, n_points): + self.num_heads = num_heads + self.num_level = num_level + self.n_points = n_points + + def __call__(self, shape, dtype=None, **kwargs): + thetas = tf.range(self.num_heads, dtype=tf.float32) * (2.0 * math.pi / self.num_heads) + grid_init = tf.stack([tf.math.cos(thetas), tf.math.sin(thetas)], axis=-1) + grid_init = grid_init / tf.math.reduce_max(tf.abs(grid_init), axis=-1, keepdims=True)[0] + grid_init = tf.reshape(grid_init, (self.num_heads, 1, 1, 2) ) + # self.num_heads, self.num_level, self.n_points, 2 + grid_init = tf.tile(grid_init, (1, self.num_level, self.n_points, 1) ) + + scaling = tf.range(self.n_points, dtype = tf.float32) + 1.0 + scaling = tf.reshape(scaling, (1, 1, self.n_points , 1) ) + grid_init = grid_init * scaling + + grid_init = tf.reshape(grid_init, (-1,)) + + return grid_init + + + +class MSDeformableAttention(tf.keras.layers.Layer): + def __init__(self, model_dim, num_heads, num_sampling_points = 4, num_level=4, dropout=0.0, **kwargs): + super().__init__(**kwargs) + + self.model_dim = model_dim + self.num_heads = num_heads + + self.num_level = num_level + self.num_sampling_points = num_sampling_points + + assert model_dim % num_heads == 0 + self.head_dim = model_dim // num_heads + + self.dropout = Dropout(rate=dropout) + + self.im2col_step = 64 + + + def build(self, input_shapes): + + self.sampling_offsets = Linear(self.num_heads * self.num_level * self.num_sampling_points * 2, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=SamplingOffsetBiasInitializer(self.num_heads, self.num_level, self.num_sampling_points), + name="sampling_offsets") + + self.attention_weights = Linear(self.num_heads * self.num_level * self.num_sampling_points, + 
kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name="attention_weights") + + self.value_proj = Linear(self.model_dim, bias_initializer=tf.keras.initializers.Zeros(), name="value_proj") + + self.output_proj = Linear(self.model_dim, name="output_proj") + + + def call(self, query, reference_points, inputs, input_spatial_shapes=None, input_level_start_index=None, input_padding_mask=None, training=False): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, 4), add additional (w, h) to form reference boxes + + :param inputs (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + or + :param inputs lvl, (N, H_l, W_l, C) + + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + :return output (N, Length_{query}, C) + """ + + #debug purpose + unstack_size = list(iss[0]*iss[1] for iss in tf.unstack(input_spatial_shapes, axis=0)) + unstack_shape = list( (iss[0], iss[1]) for iss in tf.unstack(input_spatial_shapes, axis=0)) + + N, Len_q, C = tf.unstack(tf.shape(query)) + + N, Len_in, _ = tf.unstack(tf.shape(inputs)) + value = self.value_proj(inputs) + value = tf.reshape(value, (N, Len_in, self.num_heads, self.head_dim)) + + + sampling_offsets = self.sampling_offsets(query) + sampling_offsets = tf.reshape(sampling_offsets, (N, Len_q, self.num_heads, self.num_level, self.num_sampling_points, 2) ) + + + attention_weights = self.attention_weights(query) + attention_weights = tf.reshape(attention_weights, (N, Len_q, self.num_heads, self.num_level * self.num_sampling_points) ) + attention_weights = tf.nn.softmax(attention_weights, axis=-1) + attention_weights = tf.reshape(attention_weights, (N, Len_q, self.num_heads, self.num_level, self.num_sampling_points) ) + + + # (N, Len_q, num_heads, num_level, num_sampling_points, _) + if reference_points.shape[-1] == 2: + offset_normalizer = tf.cast(tf.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1), tf.float32) + sampling_locations = reference_points[:, :, None, None, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, None, None, :2] + sampling_offsets / self.num_sampling_points * reference_points[:, :, None, None, None, 2:] * 0.5 + else: + raise ValueError(f"reference_points shape must be defined, got {reference_points.shape[-1]}") + + if USE_CUDA_MS_DEFORM_IM2COL: + # Flatten and call custom op ! 
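As an aside before the hunk continues below: a standalone sketch (toy shapes and random values, all assumed) of the broadcasting used just above to turn reference points plus predicted offsets into normalized sampling locations for the `reference_points.shape[-1] == 2` case.

```python
import tensorflow as tf

# Toy sizes only, to illustrate the broadcasting in the 2-dim reference point branch.
N, Len_q, num_heads, num_level, num_points = 1, 2, 8, 4, 4

input_spatial_shapes = tf.constant([[64, 64], [32, 32], [16, 16], [8, 8]])  # (H_l, W_l) per level
reference_points = tf.random.uniform((N, Len_q, 2))                          # normalized, in [0, 1]
sampling_offsets = tf.random.normal((N, Len_q, num_heads, num_level, num_points, 2))

# Offsets are predicted in "pixels" of each level, so dividing by (W_l, H_l) maps them
# back into the normalized coordinate frame of the reference points.
offset_normalizer = tf.cast(
    tf.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1), tf.float32)
sampling_locations = (reference_points[:, :, None, None, None, :]
                      + sampling_offsets / offset_normalizer[None, None, None, :, None, :])

print(sampling_locations.shape)  # (1, 2, 8, 4, 4, 2)
```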
+ output = ms_deform_im2col( + value, # (N, Len_in, n_heads, d_model#n_heads) + input_spatial_shapes, # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + input_level_start_index, # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + sampling_locations, # (N, Len_q, n_heads, n_levels, n_points, 2) + attention_weights # (N, Len_q, num_heads, n_level, num_sampling_points) + ) + else: + #Unflatten + value = tf.split(value, num_or_size_splits=unstack_size, axis=1) + value = list(tf.reshape(v, (N, shape[0], shape[1], self.num_heads, self.head_dim) ) for v, shape in zip(value, unstack_shape) ) + + sampling_loc = tf.unstack(sampling_locations, axis=3) #(N, Len_q, n_heads, num_sampling_points) + + output = MSDeformAttnFunction(value, sampling_loc, attention_weights) + + + output = self.output_proj(output) + + + return output + +def inverse_sigmoid(x, eps=1e-5): + x = tf.clip_by_value(x, 0.0, 1.0) + x1 = tf.clip_by_value(x, eps, 1.0) + + x2 = (1 - x) + x2 = tf.clip_by_value(x2, eps, 1.0) + return tf.math.log(x1/x2) diff --git a/detr_tf/networks/detr.py b/detr_tf/networks/detr.py index 7f2a202c..04f09156 100644 --- a/detr_tf/networks/detr.py +++ b/detr_tf/networks/detr.py @@ -100,7 +100,7 @@ def add_heads_nlayers(config, detr, nb_class): tf.keras.layers.Dense(256, activation="relu"), tf.keras.layers.Dense(4, activation="sigmoid"), ], name="pos_layer") - config.add_nlayers([cls_layer, pos_layer]) + config.add_heads([cls_layer, pos_layer]) transformer_output = detr(image_input) cls_preds = cls_layer(transformer_output) @@ -139,6 +139,7 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba load_weights(detr, weights) image_input = tf.keras.Input((None, None, 3)) + image_mask = tf.keras.Input((None, None, 1)) # Backbone if not tf_backbone: @@ -167,20 +168,23 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba bbox_embed_linear3 = detr.get_layer('bbox_embed_2') activation = detr.get_layer("re_lu") - x = backbone(image_input) + x, _, _, _ = backbone(image_input) + + # Resize the mask to the same size of the backbone outptu + masks = tf.image.resize(image_mask, (tf.shape(x)[1], tf.shape(x)[2]), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + masks = tf.cast(masks, tf.int32) - masks = tf.zeros((tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]), tf.bool) pos_encoding = position_embedding_sine(masks) hs = transformer(input_proj(x), masks, query_embed(None), pos_encoding)[0] - detr = tf.keras.Model(image_input, hs, name="detr") + detr = tf.keras.Model([image_input, image_mask], hs, name="detr") if include_top is False and nb_class is None: return detr elif include_top is False and nb_class is not None: return add_heads_nlayers(config, detr, nb_class) - transformer_output = detr(image_input) + transformer_output = detr((image_input, masks)) outputs_class = class_embed(transformer_output) box_ftmps = activation(bbox_embed_linear1(transformer_output)) @@ -201,5 +205,5 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba "pred_boxes": pred_boxes }) - return tf.keras.Model(image_input, output, name="detr_finetuning") + return tf.keras.Model([image_input, image_mask], output, name="detr_finetuning") diff --git a/detr_tf/networks/position_embeddings.py b/detr_tf/networks/position_embeddings.py index c567c380..bd9acb38 100644 --- a/detr_tf/networks/position_embeddings.py +++ b/detr_tf/networks/position_embeddings.py @@ -1,12 +1,10 @@ import numpy as np 
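For reference, a small worked example of the iterative refinement step used in the deformable decoder above. It re-implements the inverse_sigmoid helper from this hunk so the snippet is self-contained; the box and delta values are made up.

```python
import tensorflow as tf

def inverse_sigmoid(x, eps=1e-5):
    # Same clipped-logit helper as defined above.
    x = tf.clip_by_value(x, 0.0, 1.0)
    x1 = tf.clip_by_value(x, eps, 1.0)
    x2 = tf.clip_by_value(1.0 - x, eps, 1.0)
    return tf.math.log(x1 / x2)

# One decoder refinement step on a 4-dim reference box (values are illustrative):
reference_points = tf.constant([[0.25, 0.50, 0.10, 0.20]])   # (x, y, w, h) in [0, 1]
tmp = tf.constant([[0.30, -0.10, 0.05, 0.00]])               # hypothetical bbox_embed output

# Work in logit space, add the per-layer delta, then squash back into (0, 1).
new_reference_points = tf.math.sigmoid(inverse_sigmoid(reference_points) + tmp)
print(new_reference_points.numpy())  # still a valid normalized box whatever the delta is
```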
import tensorflow as tf - -class PositionEmbeddingSine(tf.keras.Model): - - +class PositionEmbeddingSine(tf.keras.layers.Layer): + # These are the default parameters used in the original project def __init__(self, num_pos_features=64, temperature=10000, - normalize=False, scale=None, eps=1e-6, **kwargs): + normalize=False, scale=None, eps=1e-6, center=False, **kwargs): super().__init__(**kwargs) self.num_pos_features = num_pos_features @@ -18,33 +16,63 @@ def __init__(self, num_pos_features=64, temperature=10000, scale = 2 * np.pi self.scale = scale self.eps = eps + self.center = center + def call(self, mask): - not_mask = tf.cast(~mask, tf.float32) - y_embed = tf.math.cumsum(not_mask, axis=1) - x_embed = tf.math.cumsum(not_mask, axis=2) + + not_mask = tf.cast(mask == 0, tf.float32) + + y_embed_mask = tf.cumsum(not_mask, axis=1) + x_embed_mask = tf.cumsum(not_mask, axis=2) + y_embed_mask = tf.squeeze(y_embed_mask, axis=-1) + x_embed_mask = tf.squeeze(x_embed_mask, axis=-1) + + #print("y_embed_mask", y_embed_mask.shape) + #print("x_embed_mask", x_embed_mask.shape) + + x = tf.range(tf.shape(mask)[2]) + 1 + y = tf.range(tf.shape(mask)[1]) + 1 + x_embed, y_embed = tf.meshgrid(x, y) + + x_embed = tf.expand_dims(x_embed, axis=0) + y_embed = tf.expand_dims(y_embed, axis=0) + + x_embed = tf.tile(x_embed, [tf.shape(mask)[0], 1, 1,]) + y_embed = tf.tile(y_embed, [tf.shape(mask)[0], 1, 1,]) + x_embed = tf.cast(x_embed, tf.float32) + y_embed = tf.cast(y_embed, tf.float32) + + #print('x_embed', x_embed.shape) + #print("y_embed", y_embed.shape) if self.normalize: - y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + if self.center: + y_embed = y_embed-0.5 + x_embed = x_embed-0.5 + y_embed = y_embed / (y_embed_mask[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed_mask[:, :, -1:] + self.eps) * self.scale dim_t = tf.range(self.num_pos_features, dtype=tf.float32) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_features) - pos_x = x_embed[..., tf.newaxis] / dim_t - pos_y = y_embed[..., tf.newaxis] / dim_t - + x_embed = tf.expand_dims(x_embed, axis=-1) + y_embed = tf.expand_dims(y_embed, axis=-1) + + pos_x = x_embed / dim_t + pos_y = y_embed / dim_t + pos_x = tf.stack([tf.math.sin(pos_x[..., 0::2]), tf.math.cos(pos_x[..., 1::2])], axis=4) pos_y = tf.stack([tf.math.sin(pos_y[..., 0::2]), tf.math.cos(pos_y[..., 1::2])], axis=4) - shape = [tf.shape(pos_x)[i] for i in range(3)] + [-1] pos_x = tf.reshape(pos_x, shape) pos_y = tf.reshape(pos_y, shape) pos_emb = tf.concat([pos_y, pos_x], axis=3) + return pos_emb diff --git a/detr_tf/networks/resnet_backbone.py b/detr_tf/networks/resnet_backbone.py index d81a5872..5e0fbf1f 100644 --- a/detr_tf/networks/resnet_backbone.py +++ b/detr_tf/networks/resnet_backbone.py @@ -25,11 +25,11 @@ def call(self, x): x = self.pad2(x) x = self.maxpool(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - return x + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + return l4, l3, l2, l1 class ResNet50Backbone(ResNetBase): diff --git a/detr_tf/networks/transformer.py b/detr_tf/networks/transformer.py index 08f34448..60c402a5 100644 --- a/detr_tf/networks/transformer.py +++ b/detr_tf/networks/transformer.py @@ -269,19 +269,6 @@ def build(self, input_shapes): - - #self.in_proj_weight = tf.Variable( - # tf.zeros((in_dim, self.model_dim), dtype=tf.float32), name='in_proj_kernel') - #self.in_proj_bias = 
tf.Variable(tf.zeros((in_dim,), dtype=tf.float32), - # name='in_proj_bias') - - #self.out_proj_weight = tf.Variable( - # tf.zeros((self.model_dim, self.model_dim), dtype=tf.float32), name='out_proj_kernel') - #self.out_proj_bias = tf.Variable( - # tf.zeros((self.model_dim,), dtype=tf.float32), name='out_proj_bias') - - - def call(self, inputs, attn_mask=None, key_padding_mask=None, need_weights=True, training=False): @@ -319,8 +306,10 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None, if attn_mask is not None: attn_output_weights += attn_mask - """ + if key_padding_mask is not None: + key_padding_mask = tf.cast(key_padding_mask, tf.bool) + attn_output_weights = tf.reshape(attn_output_weights, [batch_size, self.num_heads, target_len, source_len]) @@ -328,13 +317,13 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None, key_padding_mask = tf.expand_dims(key_padding_mask, 2) key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1]) + #print("before attn_output_weights", attn_output_weights.shape) attn_output_weights = tf.where(key_padding_mask, tf.zeros_like(attn_output_weights) + float('-inf'), attn_output_weights) attn_output_weights = tf.reshape(attn_output_weights, [batch_size * self.num_heads, target_len, source_len]) - """ attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1) diff --git a/detr_tf/networks/weights.py b/detr_tf/networks/weights.py index 87015558..c59f5cef 100644 --- a/detr_tf/networks/weights.py +++ b/detr_tf/networks/weights.py @@ -7,6 +7,16 @@ "https://storage.googleapis.com/visualbehavior-publicweights/detr/checkpoint", "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.data-00000-of-00001", "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.index" + ], + "deformable-detr": [ + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/checkpoint", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/deformable-detr.ckpt.data-00000-of-00001", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/deformable-detr.ckpt.index" + ], + "deformable-detr-refine_bbox": [ + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/checkpoint", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/deformable-detr-refine_bbox.ckpt.data-00000-of-00001", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/deformable-detr-refine_bbox.ckpt.index" ] } diff --git a/detr_tf/optimizers.py b/detr_tf/optimizers.py index ae1de914..c7f0a310 100644 --- a/detr_tf/optimizers.py +++ b/detr_tf/optimizers.py @@ -1,4 +1,5 @@ import tensorflow as tf +import tensorflow_addons as tfa def disable_batchnorm_training(model): for l in model.layers: @@ -7,20 +8,6 @@ def disable_batchnorm_training(model): elif isinstance(l, tf.keras.layers.BatchNormalization): l.trainable = False -def get_transformers_trainable_variables(model, exclude=[]): - transformers_variables = [] - - # Transformers variables - transformers_variables = model.get_layer("detr").get_layer("transformer").trainable_variables - - for layer in model.layers[2:]: - if layer.name not in exclude: - transformers_variables += layer.trainable_variables - else: - pass - - return transformers_variables - def get_backbone_trainable_variables(model): backbone_variables = [] @@ -36,11 +23,26 @@ def get_backbone_trainable_variables(model): return backbone_variables -def 
get_nlayers_trainables_variables(model, nlayers_names): - nlayers_variables = [] - for nlayer_name in nlayers_names: - nlayers_variables += model.get_layer(nlayer_name).trainable_variables - return nlayers_variables +def get_transformers_trainable_variables(model, exclude=[]): + transformers_variables = [] + + # Transformers variables + transformers_variables = model.get_layer("detr").get_layer("transformer").trainable_variables + + for layer in model.layers[2:]: + if layer.name not in exclude: + transformers_variables += layer.trainable_variables + else: + pass + + return transformers_variables + + +def get_heads_trainables_variables(model, heads_names): + heads_variables = [] + for nlayer_name in heads_names: + heads_variables += model.get_layer(nlayer_name).trainable_variables + return heads_variables def get_trainable_variables(model, config): @@ -49,19 +51,14 @@ def get_trainable_variables(model, config): backbone_variables = [] transformers_variables = [] - nlayers_variables = [] - + heads_variables = [] - # Retrieve the gradient ofr each trainable variables - #if config.train_backbone: + # The gradient will be retrieve for each trainable variable backbone_variables = get_backbone_trainable_variables(model) - #if config.train_transformers: - transformers_variables = get_transformers_trainable_variables(model, exclude=config.nlayers) - #if config.train_nlayers: - nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers) + transformers_variables = get_transformers_trainable_variables(model, exclude=config.heads) + heads_variables = get_heads_trainables_variables(model, config.heads) - - return backbone_variables, transformers_variables, nlayers_variables + return backbone_variables, transformers_variables, heads_variables def setup_optimizers(model, config): @@ -76,59 +73,79 @@ def get_transformers_learning_rate(): return config.transformers_lr @tf.function - def get_nlayers_learning_rate(): - return config.nlayers_lr + def get_heads_learning_rate(): + return config.heads_lr + + @tf.function + def get_backbone_wd(): + return config.backbone_lr*config.backbone_wd + + @tf.function + def get_transformers_wd(): + return config.transformers_lr*config.transformers_wd + + @tf.function + def get_heads_wd(): + return config.heads_lr*config.heads_wd + + # Disable batch norm on the backbone disable_batchnorm_training(model) # Optimizers - backbone_optimizer = tf.keras.optimizers.Adam(learning_rate=get_backbone_learning_rate, clipnorm=config.gradient_norm_clipping) - transformers_optimizer = tf.keras.optimizers.Adam(learning_rate=get_transformers_learning_rate, clipnorm=config.gradient_norm_clipping) - nlayers_optimizer = tf.keras.optimizers.Adam(learning_rate=get_nlayers_learning_rate, clipnorm=config.gradient_norm_clipping) + backbone_optimizer = tfa.optimizers.AdamW( + learning_rate=get_backbone_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_backbone_wd + ) + transformers_optimizer = tfa.optimizers.AdamW( + learning_rate=get_transformers_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_transformers_wd + ) + heads_optimizer = tfa.optimizers.AdamW( + learning_rate=get_heads_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_heads_wd + ) # Set trainable variables - backbone_variables, transformers_variables, nlayers_variables = [], [], [] + backbone_variables, transformers_variables, heads_variables = [], [], [] backbone_variables = get_backbone_trainable_variables(model) - transformers_variables = 
get_transformers_trainable_variables(model, exclude=config.nlayers) - nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers) + transformers_variables = get_transformers_trainable_variables(model, exclude=config.heads) + heads_variables = get_heads_trainables_variables(model, config.heads) return { "backbone_optimizer": backbone_optimizer, "transformers_optimizer": transformers_optimizer, - "nlayers_optimizer": nlayers_optimizer, + "heads_optimizer": heads_optimizer, "backbone_variables": backbone_variables, "transformers_variables": transformers_variables, - "nlayers_variables": nlayers_variables, + "heads_variables": heads_variables, } def gather_gradient(model, optimizers, total_loss, tape, config, log): - backbone_variables, transformers_variables, nlayers_variables = get_trainable_variables(model, config) - trainables_variables = backbone_variables + transformers_variables + nlayers_variables + backbone_variables, transformers_variables, heads_variables = get_trainable_variables(model, config) + trainables_variables = backbone_variables + transformers_variables + heads_variables gradients = tape.gradient(total_loss, trainables_variables) # Retrieve the gradients from the tap backbone_gradients = gradients[:len(optimizers["backbone_variables"])] transformers_gradients = gradients[len(optimizers["backbone_variables"]):len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"])] - nlayers_gradients = gradients[len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"]):] + heads_gradients = gradients[len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"]):] gradient_steps = {} gradient_steps["backbone"] = {"gradients": backbone_gradients} gradient_steps["transformers"] = {"gradients": transformers_gradients} - gradient_steps["nlayers"] = {"gradients": nlayers_gradients} + gradient_steps["heads"] = {"gradients": heads_gradients} log.update({"backbone_lr": optimizers["backbone_optimizer"]._serialize_hyperparameter("learning_rate")}) log.update({"transformers_lr": optimizers["transformers_optimizer"]._serialize_hyperparameter("learning_rate")}) - log.update({"nlayers_lr": optimizers["nlayers_optimizer"]._serialize_hyperparameter("learning_rate")}) + log.update({"heads_lr": optimizers["heads_optimizer"]._serialize_hyperparameter("learning_rate")}) return gradient_steps diff --git a/detr_tf/training.py b/detr_tf/training.py index edd23b22..f530403f 100644 --- a/detr_tf/training.py +++ b/detr_tf/training.py @@ -1,13 +1,38 @@ import tensorflow as tf +import matplotlib.pyplot as plt + from .optimizers import gather_gradient, aggregate_grad_and_apply from .logger.training_logging import valid_log, train_log from .loss.loss import get_losses import time import wandb + +def handle_data(data): + """ Create (TODO) a mask from the given ragged images. The mask will be use in the encoder/decode + attention. 
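Stepping back to the optimizer changes above: a minimal standalone sketch of the per-group AdamW pattern, assuming tensorflow_addons is installed (it is imported at the top of optimizers.py in this diff). The learning-rate and decay values are placeholders for the TrainingConfig variables.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Placeholders standing in for config.backbone_lr / config.backbone_wd.
backbone_lr = tf.Variable(1e-5)
backbone_wd_ratio = 1e-4

def get_backbone_lr():
    return backbone_lr

def get_backbone_wd():
    # Weight decay follows the current learning rate, so updating the lr variable
    # during training also rescales the decay, as in setup_optimizers() above.
    return backbone_lr * backbone_wd_ratio

backbone_optimizer = tfa.optimizers.AdamW(
    learning_rate=get_backbone_lr,
    weight_decay=get_backbone_wd,
    clipnorm=0.1,
)
```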
+ """ + n_data = {} + for key in data: + n_data[key] = data[key] + + padding_mask = tf.ones_like(data["images"]) + # The following will add 0 on all padded part + padding_mask = padding_mask.to_tensor()[:,:,:,:1] + # Set one instead of zero on all paded part + padding_mask = tf.abs(padding_mask - 1) + + n_data["images"] = n_data["images"].to_tensor() + n_data["mask"] = padding_mask + + return n_data + + @tf.function -def run_train_step(model, images, t_bbox, t_class, optimizers, config): +def run_train_step(model, data, optimizers, config): + + n_data = handle_data(data) if config.target_batch is not None: gradient_aggregate = int(config.target_batch // config.batch_size) @@ -15,8 +40,8 @@ def run_train_step(model, images, t_bbox, t_class, optimizers, config): gradient_aggregate = 1 with tf.GradientTape() as tape: - m_outputs = model(images, training=True) - total_loss, log = get_losses(m_outputs, t_bbox, t_class, config) + m_outputs = model((n_data["images"], n_data["mask"]), training=True) + total_loss, log = get_losses(m_outputs, t_bbox=n_data["target_bbox"], t_class=n_data["target_class"], config=config, batch_size=config.batch_size) total_loss = total_loss / gradient_aggregate # Compute gradient for each part of the network @@ -26,9 +51,12 @@ def run_train_step(model, images, t_bbox, t_class, optimizers, config): @tf.function -def run_val_step(model, images, t_bbox, t_class, config): - m_outputs = model(images, training=False) - total_loss, log = get_losses(m_outputs, t_bbox, t_class, config) +def run_val_step(model, data, config, batch_size): + + n_data = handle_data(data) + + m_outputs = model((n_data["images"], n_data["mask"]), training=False) + total_loss, log = get_losses(m_outputs, t_bbox=n_data["target_bbox"], t_class=n_data["target_class"], config=config, batch_size=batch_size) return m_outputs, total_loss, log @@ -40,14 +68,16 @@ def fit(model, train_dt, optimizers, config, epoch_nb, class_names): if config.target_batch is not None: gradient_aggregate = int(config.target_batch // config.batch_size) t = None - for epoch_step , (images, t_bbox, t_class) in enumerate(train_dt): + for epoch_step , data in enumerate(train_dt): + + data = train_dt.itertuple2dict(data) # Run the prediction and retrieve the gradient step for each part of the network - m_outputs, total_loss, log, gradient_steps = run_train_step(model, images, t_bbox, t_class, optimizers, config) + m_outputs, total_loss, log, gradient_steps = run_train_step(model, data, optimizers, config) # Load the predictions if config.log: - train_log(images, t_bbox, t_class, m_outputs, config, config.global_step, class_names, prefix="train/") + train_log(handle_data(data), m_outputs, config, config.global_step, class_names, prefix="train/") # Aggregate and apply the gradient for name in gradient_steps: @@ -65,16 +95,18 @@ def fit(model, train_dt, optimizers, config, epoch_nb, class_names): config.global_step += 1 -def eval(model, valid_dt, config, class_name, evaluation_step=200): +def eval(model, valid_dt, config, class_name, evaluation_step=200, batch_size=None): """ Evaluate the model on the validation set """ + batch_size = config.batch_size if batch_size is None else batch_size t = None - for val_step, (images, t_bbox, t_class) in enumerate(valid_dt): + for val_step, data in enumerate(valid_dt): + data = valid_dt.itertuple2dict(data) # Run prediction - m_outputs, total_loss, log = run_val_step(model, images, t_bbox, t_class, config) + m_outputs, total_loss, log = run_val_step(model, data, config, batch_size) # Log the 
predictions if config.log: - valid_log(images, t_bbox, t_class, m_outputs, config, val_step, config.global_step, class_name, evaluation_step=evaluation_step, prefix="train/") + valid_log(handle_data(data), m_outputs, config, batch_size, val_step, config.global_step, class_name, evaluation_step=evaluation_step, prefix="train/") # Log the metrics if config.log and val_step == 0: wandb.log({f"val/{k}":log[k] for k in log}, step=config.global_step) diff --git a/detr_tf/training_config.py b/detr_tf/training_config.py index 4f884d7b..63a0adbb 100644 --- a/detr_tf/training_config.py +++ b/detr_tf/training_config.py @@ -19,18 +19,24 @@ def training_config_parser(): # What to train parser.add_argument("--train_backbone", action='store_true', required=False, default=False, help="Train backbone") parser.add_argument("--train_transformers", action='store_true', required=False, default=False, help="Train transformers") - parser.add_argument("--train_nlayers", action='store_true', required=False, default=False, help="Train new layers") + parser.add_argument("--train_heads", action='store_true', required=False, default=False, help="Train the model heads (For finetuning)") # How to train + parser.add_argument("--image_size", default=None, required=False, type=str) parser.add_argument("--finetuning", default=False, required=False, action='store_true', help="Load the model weight before to train") parser.add_argument("--batch_size", type=int, required=False, default=1, help="Batch size to use to train the model") parser.add_argument("--gradient_norm_clipping", type=float, required=False, default=0.1, help="Gradient norm clipping") parser.add_argument("--target_batch", type=int, required=False, default=None, help="When running on a single GPU, aggretate the gradient before to apply.") # Learning rate - parser.add_argument("--backbone_lr", type=bool, required=False, default=1e-5, help="Train backbone") - parser.add_argument("--transformers_lr", type=bool, required=False, default=1e-4, help="Train transformers") - parser.add_argument("--nlayers_lr", type=bool, required=False, default=1e-4, help="Train new layers") + parser.add_argument("--backbone_lr", type=float, required=False, default=1e-5, help="Backbone learning rate") + parser.add_argument("--transformers_lr", type=float, required=False, default=1e-4, help="Transformer learning rate") + parser.add_argument("--heads_lr", type=float, required=False, default=1e-4, help="Model heads learning rate") + + # Weight decay + parser.add_argument("--backbone_wd", type=float, required=False, default=1e-4, help="Backbone weight decay") + parser.add_argument("--transformers_wd", type=float, required=False, default=1e-4, help="Transformer weight decay") + parser.add_argument("--heads_wd", type=float, required=False, default=1e-4, help="Model heads weight decay") # Logging parser.add_argument("--log", required=False, action="store_true", default=False, help="Log into wandb") @@ -46,12 +52,16 @@ def __init__(self): self.data_dir, self.img_dir, self.ann_dir, self.ann_file = None, None, None, None self.data = DataConfig(data_dir=None, img_dir=None, ann_file=None, ann_dir=None) self.background_class = 0 - self.image_size = 376, 672 + + #self.image_size = 376, 672 + # If image size is None, then multi scale training will be used as + # described in the paper. 
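A quick sanity check of the padding-mask construction in handle_data() above, using a made-up ragged batch of two tiny "images" of different spatial sizes (3 channels each).

```python
import tensorflow as tf

img_a = tf.ones((2, 3, 3))   # 2x3 image
img_b = tf.ones((1, 2, 3))   # 1x2 image
images = tf.ragged.stack([img_a, img_b])

padding_mask = tf.ones_like(images)
padding_mask = padding_mask.to_tensor()[:, :, :, :1]  # dense: padded pixels become 0
padding_mask = tf.abs(padding_mask - 1)               # flip: 1 = padding, 0 = real pixel

print(padding_mask[1, :, :, 0].numpy())
# -> [[0. 0. 1.]
#     [1. 1. 1.]]
```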
+ self.image_size = None # What to train self.train_backbone = False self.train_transformers = False - self.train_nlayers = False + self.train_heads = False # How to train self.finetuning = False @@ -65,8 +75,17 @@ def __init__(self): # keeping the same graph self.backbone_lr = tf.Variable(1e-5) self.transformers_lr = tf.Variable(1e-4) - self.nlayers_lr = tf.Variable(1e-4) - self.nlayers = [] + self.heads_lr = tf.Variable(1e-4) + + # Weidht decay + # Set as tf.Variable so that the variable can be update during the training while + # keeping the same graph + self.backbone_wd = tf.Variable(1e-4) + self.transformers_wd = tf.Variable(1e-4) + self.heads_wd = tf.Variable(1e-4) + + # Heads layer list + self.heads = [] # Training progress self.global_step = 0 @@ -76,10 +95,10 @@ def __init__(self): self.normalized_method = "torch_resnet" - def add_nlayers(self, layers): + def add_heads(self, layers): """ Set the new layers to train on the training config """ - self.nlayers = [l.name for l in layers] + self.heads = [l.name for l in layers] def update_from_args(self, args): @@ -87,14 +106,16 @@ def update_from_args(self, args): """ args = vars(args) for key in args: - if isinstance(getattr(self, key), tf.Variable): + if isinstance(getattr(self, key, None), tf.Variable): getattr(self, key).assign(args[key]) else: setattr(self, key, args[key]) - # Set the config on the data class - + if self.image_size is not None: + img_size = self.image_size.split(",") + self.image_size = (int(img_size[0]), int(img_size[1])) + # Set the config on the data class self.data = DataConfig( data_dir=self.data_dir, img_dir=self.img_dir, diff --git a/eval.py b/eval.py index 98c6189b..93b9bc7c 100644 --- a/eval.py +++ b/eval.py @@ -14,6 +14,11 @@ from detr_tf.bbox import xcycwh_to_xy_min_xy_max, xcycwh_to_yx_min_yx_max from detr_tf.inference import numpy_bbox_to_image from detr_tf.training_config import TrainingConfig, training_config_parser +from detr_tf.training import handle_data + +tf.random.set_seed(40) +np.random.seed(40) + def build_model(config): @@ -27,6 +32,12 @@ def build_model(config): return detr +#@tf.function +def run_model(data, model): + n_data = handle_data(data) + return model((n_data["images"], n_data["mask"])) + + def eval_model(model, config, class_names, valid_dt): """ Run evaluation """ @@ -38,18 +49,26 @@ def eval_model(model, config, class_names, valid_dt): } it = 0 - for images, target_bbox, target_class in valid_dt: + for data in valid_dt: + data = valid_dt.itertuple2dict(data) + # Forward pass - m_outputs = model(images) + m_outputs = run_model(data, model) + # Run predictions p_bbox, p_labels, p_scores = get_model_inference(m_outputs, config.background_class, bbox_format="yxyx") + # Remove padding - t_bbox, t_class = target_bbox[0], target_class[0] - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) + t_bbox, t_class = data["target_bbox"][0], data["target_class"][0] + t_bbox = xcycwh_to_yx_min_yx_max(t_bbox) - t_class = tf.slice(t_class, [1, 0], [size, -1]) t_class = tf.squeeze(t_class, axis=-1) + + # Filter undesired target + _filter = tf.squeeze(tf.where(t_class != -1), axis=-1) + t_class = tf.gather(t_class, _filter) + t_bbox = tf.gather(t_bbox, _filter) + # Compute map cal_map(p_bbox, p_labels, p_scores, np.zeros((138, 138, len(p_bbox))), np.array(t_bbox), np.array(t_class), np.zeros((138, 138, len(t_bbox))), ap_data, iou_thresholds) print(f"Computing map.....{it}", end="\r") @@ -73,7 +92,7 @@ def eval_model(model, config, class_names, valid_dt): # Load 
the model with the new layers to finetune detr = build_model(config) - valid_dt, class_names = load_coco_dataset(config, 1, augmentation=None) + valid_dt, class_names = load_coco_dataset(config, 1, augmentation=False, shuffle_data=False) # Run training eval_model(detr, config, class_names, valid_dt) diff --git a/finetune_coco.py b/finetune_coco.py index e1945918..44626e70 100644 --- a/finetune_coco.py +++ b/finetune_coco.py @@ -30,7 +30,7 @@ def build_model(config): """ Build the model with the pretrained weights. In this example we do not add new layers since the pretrained model is already trained on coco. - See examples/finetuning_voc.py to add new layers. + See the finetuning_voc.py script for an example of how to change the number of classes on the last layer. """ # Load the pretrained model detr = get_detr_model(config, include_top=True, weights="detr") @@ -44,8 +44,10 @@ def run_finetuning(config): detr = build_model(config) # Load the training and validation dataset - train_dt, coco_class_names = load_coco_dataset("train", config.batch_size, config, augmentation=True) - valid_dt, _ = load_coco_dataset("val", 1, config, augmentation=False) + train_dt, coco_class_names = load_coco_dataset( + config, config.batch_size, augmentation=True, img_dir="val2017", ann_file="annotations/instances_val2017.json") + valid_dt, _ = load_coco_dataset( + config, 1, augmentation=False, img_dir="val2017", ann_file="annotations/instances_val2017.json") # Train/finetune the transformers only config.train_backbone = False @@ -56,7 +58,7 @@ # Run the training for 5 epochs for epoch_nb in range(100): - training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) + training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=100, batch_size=1) training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names) diff --git a/finetune_voc.py b/finetune_voc.py index 9bf56674..cccb1f99 100644 --- a/finetune_voc.py +++ b/finetune_voc.py @@ -31,14 +31,16 @@ ] def build_model(config): - """ Build the model with the pretrained weights - and add new layers to finetune + """ Build the model with the pretrained weights. + We set include_top to False so as not to include the last layer of the transformer. + Then, `nb_class` is used to automatically replace the last layers with new layers that have the + appropriate number of target classes. """ # Input image_input = tf.keras.Input((None, None, 3)) - - # Load the pretrained model - detr = get_detr_model(config, include_top=False, weights="detr", num_decoder_layers=6, num_encoder_layers=6) + # Load the pretrained model and replace the last layers for this new task. 
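For context, a minimal sketch of how the head-replacement path is intended to be used. The 21-class VOC setup follows the hunks above; the exact layer names registered on config.heads and the returned model signature depend on add_heads_nlayers().

```python
from detr_tf.networks.detr import get_detr_model
from detr_tf.training_config import TrainingConfig

config = TrainingConfig()

# include_top=False together with nb_class replaces the prediction heads with fresh
# layers sized for the new dataset (VOC: 20 classes + 1 background here).
detr = get_detr_model(config, include_top=False, weights="detr", nb_class=21)

# add_heads_nlayers() registers the new layers on the config so that
# get_heads_trainables_variables() / setup_optimizers() can give them their own
# learning rate during the first finetuning phase.
print(config.heads)  # expected something like ['cls_layer', 'pos_layer']
```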
+ detr = get_detr_model(config, include_top=False, weights="detr", nb_class=len(VOC_CLASS_NAME)+1) + return detr # Setup the new layers cls_layer = tf.keras.layers.Dense(len(VOC_CLASS_NAME) + 1, name="cls_layer") @@ -68,7 +70,7 @@ def run_finetuning(config): detr = build_model(config) # Load the training and validation dataset (for the purpose of this example we're gonna load the training - # as the validation, but in practise you should have different folder loader for the training and the validation) + # as the validation, but in practise you should have different folder and loader for the training and the validation) train_dt, class_names = load_voc_dataset(config, config.batch_size, augmentation=True) valid_dt, _ = load_voc_dataset(config, 1, augmentation=False) @@ -95,8 +97,8 @@ def run_finetuning(config): config.transformers_lr.assign(1e-4) config.nlayers_lr.assign(1e-3) - training.eval(detr, valid_dt, config, class_names, evaluation_step=200) training.fit(detr, train_dt, optimzers, config, epoch_nb, class_names) + training.eval(detr, valid_dt, config, class_names, evaluation_step=200) if __name__ == "__main__": diff --git a/train_coco.py b/train_coco.py index c4ac578c..b0d97818 100644 --- a/train_coco.py +++ b/train_coco.py @@ -48,9 +48,9 @@ def run_finetuning(config): # Load the training and validation dataset train_dt, coco_class_names = load_coco_dataset( - config, config.batch_size, augmentation=True, img_dir="train2017", ann_fil="annotations/instances_train2017.json") + config, config.batch_size, augmentation=True, img_dir="val2017", ann_file="annotations/instances_val2017.json") valid_dt, _ = load_coco_dataset( - config, 1, augmentation=False, img_dir="val2017", ann_fil="annotations/instances_val2017.json") + config, 1, augmentation=False, img_dir="val2017", ann_file="annotations/instances_val2017.json") # Train the backbone and the transformers # Check the training_config file for the other hyperparameters @@ -62,8 +62,8 @@ def run_finetuning(config): # Run the training for 100 epochs for epoch_nb in range(100): - training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names) + #training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) if __name__ == "__main__": diff --git a/webcam_inference.py b/webcam_inference.py index 63871187..2e26109d 100644 --- a/webcam_inference.py +++ b/webcam_inference.py @@ -3,19 +3,29 @@ import cv2 from detr_tf.training_config import TrainingConfig, training_config_parser + from detr_tf.networks.detr import get_detr_model +from detr_tf.networks.deformable_detr import get_deformable_detr_model + from detr_tf.data import processing from detr_tf.data.coco import COCO_CLASS_NAME from detr_tf.inference import get_model_inference, numpy_bbox_to_image + @tf.function -def run_inference(model, images, config): - m_outputs = model(images, training=False) - predicted_bbox, predicted_labels, predicted_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center") +def run_inference(model, images, config, use_mask=True): + + if use_mask: + mask = tf.zeros((1, images.shape[1], images.shape[2], 1)) + m_outputs = model((images, mask), training=False) + else: + m_outputs = model(images, training=False) + + predicted_bbox, predicted_labels, predicted_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center", threshold=0.2) return predicted_bbox, predicted_labels, predicted_scores -def 
run_webcam_inference(detr): +def run_webcam_inference(model, use_mask=True): cap = cv2.VideoCapture(0) @@ -27,7 +37,7 @@ def run_webcam_inference(detr): model_input = processing.normalized_images(model_input, config) # Run inference - predicted_bbox, predicted_labels, predicted_scores = run_inference(detr, np.expand_dims(model_input, axis=0), config) + predicted_bbox, predicted_labels, predicted_scores = run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask) frame = frame.astype(np.float32) frame = frame / 255 @@ -48,12 +58,24 @@ def run_webcam_inference(detr): tf.config.experimental.set_memory_growth(physical_devices[0], True) config = TrainingConfig() - args = training_config_parser().parse_args() + parser = training_config_parser() + + # Logging + parser.add_argument("model", type=str, help="One of 'detr', or 'deformable-detr'") + args = parser.parse_args() config.update_from_args(args) - # Load the model with the new layers to finetune - detr = get_detr_model(config, include_top=True, weights="detr") - config.background_class = 91 + if args.model == "detr": + print("Loading detr...") + # Load the model with the new layers to finetune + model = get_detr_model(config, include_top=True, weights="detr") + config.background_class = 91 + use_mask = True + elif args.model == "deformable-detr": + print("Loading deformable-detr...") + model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr") + model.summary() + use_mask = False # Run webcam inference - run_webcam_inference(detr) + run_webcam_inference(model, use_mask=use_mask)
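Finally, a short usage sketch for the new deformable checkpoints published in weights.py and for the updated webcam entry point; the command lines below are illustrative.

```python
from detr_tf.training_config import TrainingConfig
from detr_tf.networks.deformable_detr import get_deformable_detr_model

config = TrainingConfig()

# Loads the "deformable-detr" checkpoint added to weights.py above; a
# "deformable-detr-refine_bbox" checkpoint is also published for the box-refinement variant.
model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr")
model.summary()

# From the command line, the updated script now takes the model as a positional argument:
#   python webcam_inference.py detr
#   python webcam_inference.py deformable-detr
```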