From e3cdd22a61ec5b809ce1bf90864ccc9a2ee52354 Mon Sep 17 00:00:00 2001 From: Thibault Date: Tue, 15 Jun 2021 09:42:38 +0200 Subject: [PATCH 1/5] WIP Multi-scale training + Mask handling on detr --- detr_tf/data/coco.py | 77 +++++++++--- detr_tf/data/processing.py | 10 +- detr_tf/data/tfcsv.py | 22 +++- detr_tf/data/transformation.py | 155 +++++++++++++++++------- detr_tf/data/voc.py | 32 +++-- detr_tf/logger/training_logging.py | 40 +++--- detr_tf/logger/wandb_logging.py | 2 +- detr_tf/loss/compute_map.py | 10 +- detr_tf/loss/hungarian_matching.py | 12 +- detr_tf/loss/loss.py | 11 +- detr_tf/networks/detr.py | 14 ++- detr_tf/networks/position_embeddings.py | 56 ++++++--- detr_tf/networks/transformer.py | 19 +-- detr_tf/optimizers.py | 105 +++++++++------- detr_tf/training.py | 58 +++++++-- detr_tf/training_config.py | 45 +++++-- eval.py | 37 ++++-- finetune_coco.py | 10 +- finetune_voc.py | 16 +-- train_coco.py | 6 +- 20 files changed, 507 insertions(+), 230 deletions(-) diff --git a/detr_tf/data/coco.py b/detr_tf/data/coco.py index e8f48557..d58d38ce 100644 --- a/detr_tf/data/coco.py +++ b/detr_tf/data/coco.py @@ -72,6 +72,12 @@ def get_coco_from_id(coco_id, coco, augmentation, config, img_dir): # Apply augmentations if len(t_bbox) > 0 and augmentation is not None: image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation) + + # If instance into the image, set at least one bbox with -1 everywhere + # This kind of bbox and class will be ignore at training + if len(t_bbox) == 0: t_bbox = np.zeros((1, 4)) - 1 + if len(t_class) == 0: t_class = np.zeros((1, 4)) - 1 + # Normalized images image = processing.normalized_images(image, config) # Set type for tensorflow @@ -79,17 +85,60 @@ def get_coco_from_id(coco_id, coco, augmentation, config, img_dir): t_bbox = t_bbox.astype(np.float32) t_class = t_class.astype(np.int64) is_crowd = np.array(is_crowd, dtype=np.int64) - return image, t_bbox, t_class, is_crowd + + return image, t_bbox, t_class#, is_crowd + + +def tensor_to_ragged(image, t_bbox, t_class): + # Images can have different size in multi-scale training + # Also, each image can have different number of instance. + # Therefore, we can use ragged tensor to handle Tensor with dynamic shapes. + # None is consider as Dynamic in the shape by the Ragged Tensor. 
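+    # A minimal sketch of that idea (illustrative only, the shapes below are made up):
+    #   a = tf.zeros((480, 640, 3)); b = tf.zeros((512, 512, 3))
+    #   tf.ragged.stack([a, b]).to_tensor()  # -> dense (2, 512, 640, 3), zero-padded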
+ image.set_shape(tf.TensorShape([None, None, 3])) + image = tf.RaggedTensor.from_tensor(image).to_tensor() + t_bbox.set_shape(tf.TensorShape([None, 4])) + t_bbox = tf.RaggedTensor.from_tensor(t_bbox).to_tensor() + t_class.set_shape(tf.TensorShape([None, 1])) + t_class = tf.RaggedTensor.from_tensor(t_class).to_tensor() + return image, t_bbox, t_class + + +def iter_tuple_to_dict(data): + image, t_bbox, t_class = data + return { + "images": image, + "target_bbox": t_bbox, + "target_class": t_class + } def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None): """ Load a coco dataset + + Parameters + ---------- + config: TrainingConfig + Instance of TrainingConfig + batch_size: int + Size of the desired batch size + augmentation: bool + Apply augmentations on the training data + ann_dir: str + Path to the coco dataset + If None, will be equal to config.data.ann_dir + ann_file: str + Path to the ann_file relative to the ann_dir + If None, will be equal to config.data.ann_file + img_dir: str + Path to the img_dir relative to the data_dir + If None, will be equal to config.data.img_dir """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir - ann_file = config.data.ann_file if ann_file is None else ann_file - img_dir = config.data.img_dir if img_dir is None else img_dir - - + if ann_dir is None: + ann_file = config.data.ann_file if ann_file is None else os.path.join(config.data_dir, ann_file) + else: + ann_file = config.data.ann_file if ann_file is None else os.path.join(ann_dir, ann_file) + img_dir = config.data.img_dir if img_dir is None else os.path.join(config.data_dir, img_dir) coco = COCO(ann_file) @@ -106,22 +155,22 @@ def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_ # Setup the data pipeline img_ids = coco.getImgIds() - shuffle(img_ids) + + #shuffle(img_ids) dataset = tf.data.Dataset.from_tensor_slices(img_ids) # Shuffle the dataset - dataset = dataset.shuffle(1000) + #dataset = dataset.shuffle(1000) + # Retrieve img and labels - outputs_types=(tf.float32, tf.float32, tf.int64, tf.int64) + outputs_types=(tf.float32, tf.float32, tf.int64) dataset = dataset.map(lambda idx: processing.numpy_fc( idx, get_coco_from_id, outputs_types=outputs_types, coco=coco, augmentation=augmentation, config=config, img_dir=img_dir) , num_parallel_calls=tf.data.experimental.AUTOTUNE) - dataset = dataset.filter(lambda imgs, tbbox, tclass, iscrowd: tf.shape(tbbox)[0] > 0 and iscrowd != 1) - dataset = dataset.map(lambda imgs, tbbox, tclass, iscrowd: (imgs, tbbox, tclass), num_parallel_calls=tf.data.experimental.AUTOTUNE) - # Pad bbox and labels - dataset = dataset.map(processing.pad_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE) - - dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.map(tensor_to_ragged, num_parallel_calls=tf.data.experimental.AUTOTUNE) + dataset = dataset.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size, drop_remainder=True)) dataset = dataset.prefetch(32) + + dataset.itertuple2dict = lambda data: iter_tuple_to_dict(data) return dataset, class_names \ No newline at end of file diff --git a/detr_tf/data/processing.py b/detr_tf/data/processing.py index 4629c398..be72b24c 100644 --- a/detr_tf/data/processing.py +++ b/detr_tf/data/processing.py @@ -28,8 +28,14 @@ def numpy_fc(idx, fc, outputs_types=(tf.float32, tf.float32, tf.int64), **params Call a numpy function on each given ID (`idx`) and load the associated image and labels (bbbox and cls) """ def 
_np_function(_idx): - return fc(_idx, **params) - return tf.numpy_function(_np_function, [idx], outputs_types) + data = fc(_idx, **params) + return data + + data = tf.numpy_function(_np_function, [idx], outputs_types) + + #data = tuple(map(lambda x : tf.RaggedTensor.from_tensor(x).to_tensor(), data)) + + return data def pad_labels(images: tf.Tensor, t_bbox: tf.Tensor, t_class: tf.Tensor): diff --git a/detr_tf/data/tfcsv.py b/detr_tf/data/tfcsv.py index 3503a56e..e2c27633 100644 --- a/detr_tf/data/tfcsv.py +++ b/detr_tf/data/tfcsv.py @@ -36,7 +36,27 @@ def load_data_from_index(index, class_names, filenames, anns, config, augmentati def load_tfcsv_dataset(config, batch_size, augmentation=False, exclude=[], ann_dir=None, ann_file=None, img_dir=None): - """ Load the hardhat dataset + """ Load a Tensorflow csv Dataset + + Parameters + ---------- + config: TrainingConfig + Instance of TrainingConfig + batch_size: int + Size of the desired batch size + augmentation: bool + Apply augmentations on the training data + exclude: list + Exclude some class from the training. Nothing happen if empty. + ann_dir: str + Path to the coco dataset + If None, will be equal to config.data.ann_dir + ann_file: str + Path to the ann_file relative to the ann_dir + If None, will be equal to config.data.ann_file + img_dir: str + Path to the img_dir relative to the data_dir + If None, will be equal to config.data.img_dir """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir ann_file = config.data.ann_file if ann_file is None else ann_file diff --git a/detr_tf/data/transformation.py b/detr_tf/data/transformation.py index be0e9cc8..9c2e3eac 100644 --- a/detr_tf/data/transformation.py +++ b/detr_tf/data/transformation.py @@ -2,12 +2,106 @@ import imgaug as ia import imgaug.augmenters as iaa import numpy as np +import random from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage from imgaug.augmentables.segmaps import SegmentationMapsOnImage import tensorflow as tf + +def get_size_with_aspect_ratio(w, h, size, max_size=None): + + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + +def get_multiscale_transform(images, + scales=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800], + random_crop=(384, 600), + random_resize=[400, 500, 600], + max_size=None): + """ Coco Augmentation pipeline + """ + h, w, _ = images.shape + + one_of = [] + + scale = np.random.choice(scales) + scale_height, scake_width = get_size_with_aspect_ratio(w, h, scale, max_size=max_size) + one_of.append( + iaa.Resize({"height": scale_height, "width": scake_width}) + ) + + random_resize_crop = [] + if random_resize is not None and len(random_resize) > 0: + scale = np.random.choice(random_resize) + resize_height, resize_width = get_size_with_aspect_ratio(w, h, scale) + random_resize_crop.append( + iaa.Resize({"height": resize_height, "width": resize_width}) + ) + if random_crop is not None: + crop_width = random.randint(random_crop[0], random_crop[1]) + crop_height = random.randint(random_crop[0], random_crop[1]) + random_resize_crop.append( + iaa.CropToFixedSize(crop_width, crop_height) + ) + + random_resize_crop.append( + iaa.Resize({"height": 
scale_height, "width": scake_width}) + ) + + one_of.append(iaa.Sequential(random_resize_crop)) + + seq = iaa.OneOf(one_of) + return seq + + + +def get_train_fixedsize_transform(image_size): + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + seq = iaa.Sequential([ + iaa.Fliplr(0.5), # horizontal flips + sometimes(iaa.OneOf([ + # Resize complety the image + iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL), + # Crop into the image + iaa.CropToFixedSize(image_size[1], image_size[0]), + # Affine transform + iaa.Affine( + scale={"x": (0.5, 1.5), "y": (0.5, 1.5)}, + ) + ])), + # Be sure to resize to the target image size + iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL) + ], random_order=False) # apply augmenters in random order + return seq + + +def get_valid_fixedsize_transform(image_size): + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + seq = iaa.Sequential([ + # Be sure to resize to the target image size + iaa.Resize({"width": image_size[1], "height": image_size[0]}) + ], random_order=False) # apply augmenters in random order + return seq + + def bbox_xcyc_wh_to_imgaug_bbox(bbox, target_class, height, width): img_aug_bbox = [] @@ -64,52 +158,23 @@ def detr_aug_seq(image, config, augmenation): max_side_max = 1333 image_size = config.image_size - if augmenation: - - seq = iaa.Sequential([ - iaa.Fliplr(0.5), # horizontal flips - sometimes(iaa.OneOf([ - # Resize complety the image - iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL), - # Crop into the image - iaa.CropToFixedSize(image_size[1], image_size[0]), - # Affine transform - iaa.Affine( - scale={"x": (0.5, 1.5), "y": (0.5, 1.5)}, - ) - ])), - # Be sure to resize to the target image size - iaa.Resize({"width": image_size[1], "height": image_size[0]}, interpolation=ia.ALL) - ], random_order=False) # apply augmenters in random order - - return seq - + # Multi scale training + if image_size is None: + if augmenation: + return get_multiscale_transform(image, max_size=1300) + else: + return get_multiscale_transform( + image, + scales=[800], + random_crop=None, + random_resize=None, + max_size=1300 + ) else: - - seq = iaa.Sequential([ - # Be sure to resize to the target image size - iaa.Resize({"width": image_size[1], "height": image_size[0]}) - ], random_order=False) # apply augmenters in random order - - return seq - - """ Mode paper evaluation - # Evaluation mode, we took the largest min side the model is trained on - target_min_side_size = 480 - image_min_side = min(float(image.shape[0]), float(image.shape[1])) - image_max_side = max(float(image.shape[0]), float(image.shape[1])) - - min_side_scaling = target_min_side_size / image_min_side - max_side_scaling = max_side_max / image_max_side - scaling = min(min_side_scaling, max_side_scaling) - - n_height = int(scaling * image.shape[0]) - n_width = int(scaling * image.shape[1]) - - seq = iaa.Sequential([ - iaa.Resize({"height": n_height, "width": n_width}), - ]) - """ + if augmenation: + return get_train_fixedsize_transform(image_size) + else: + return get_valid_fixedsize_transform(image_size) return seq diff --git a/detr_tf/data/voc.py b/detr_tf/data/voc.py index 34e88c30..8390e672 100644 --- a/detr_tf/data/voc.py +++ b/detr_tf/data/voc.py @@ -19,9 +19,9 @@ 'sheep', 'sofa', 'train', 'tvmonitor' ] -def load_voc_labels(img_id, class_names, voc_dir, augmentation, config): +def load_voc_labels(img_id, class_names, voc_dir, ann_dir, augmentation, config): - anno_path = os.path.join(voc_dir, 
config.data.ann_dir, img_id + '.xml') + anno_path = os.path.join(voc_dir, ann_dir, img_id + '.xml') objects = ET.parse(anno_path).findall('object') size = ET.parse(anno_path).find('size') width = float(size.find("width").text) @@ -55,13 +55,13 @@ def load_voc_labels(img_id, class_names, voc_dir, augmentation, config): return t_bbox, t_class -def load_voc_from_id(img_id, class_names, voc_dir, augmentation, config, img_dir): +def load_voc_from_id(img_id, class_names, voc_dir, ann_dir, augmentation, config, img_dir): img_id = str(img_id.decode()) # Load image - img_path = os.path.join(voc_dir, config.data.img_dir, img_id + '.jpg') + img_path = os.path.join(voc_dir, img_dir, img_id + '.jpg') image = imageio.imread(img_path) # Load labels - t_bbox, t_class = load_voc_labels(img_id, class_names, voc_dir, augmentation, config) + t_bbox, t_class = load_voc_labels(img_id, class_names, voc_dir, ann_dir, augmentation, config) # Apply augmentations if augmentation is not None: image, t_bbox, t_class = transformation.detr_transform(image, t_bbox, t_class, config, augmentation) @@ -77,7 +77,25 @@ def load_voc_from_id(img_id, class_names, voc_dir, augmentation, config, img_dir def load_voc_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None): - """ + """ Load a VOC dataset + + Parameters + ---------- + config: TrainingConfig + Instance of TrainingConfig + batch_size: int + Size of the desired batch size + augmentation: bool + Apply augmentations on the training data + ann_dir: str + Path to the coco dataset + If None, will be equal to config.data.ann_dir + ann_file: str + Path to the ann_file relative to the ann_dir + If None, will be equal to config.data.ann_file + img_dir: str + Path to the img_dir relative to the data_dir + If None, will be equal to config.data.img_dir """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir ann_file = config.data.ann_file if ann_file is None else ann_file @@ -115,7 +133,7 @@ def load_voc_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_f dataset = dataset.shuffle(1000) # Retrieve img and labels dataset = dataset.map(lambda idx: processing.numpy_fc(idx, load_voc_from_id, - class_names=class_names, voc_dir=config.data.data_dir, augmentation=augmentation, config=config, img_dir=img_dir) + class_names=class_names, voc_dir=config.data.data_dir, ann_dir=ann_dir, augmentation=augmentation, config=config, img_dir=img_dir) , num_parallel_calls=tf.data.experimental.AUTOTUNE) # Filter labels to be sure to keep only sample with at least one bbox dataset = dataset.filter(lambda imgs, tbbox, tclass: tf.shape(tbbox)[0] > 0) diff --git a/detr_tf/logger/training_logging.py b/detr_tf/logger/training_logging.py index 2291033e..c3560e29 100644 --- a/detr_tf/logger/training_logging.py +++ b/detr_tf/logger/training_logging.py @@ -21,13 +21,13 @@ RAGGED = False -def tf_send_batch_log_to_wandb(images, target_bbox, target_class, m_outputs: dict, config, class_name=[], step=None, prefix=""): +def tf_send_batch_log_to_wandb(images, target_bbox, target_class, m_outputs: dict, config, batch_size, class_name=[], step=None, prefix=""): # Warning: In graph mode, this class is init only once. In eager mode, this class is init at each step. 
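    # A small illustration of that warning (hypothetical function, not part of this module):
    #   @tf.function
    #   def log_step(x):
    #       sender = WandbSender()  # Python-side construction runs at trace time,
    #       return x                # i.e. once per traced signature, not once per call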
img_sender = WandbSender() predicted_bbox = m_outputs["pred_boxes"] - for b in range(predicted_bbox.shape[0]): + for b in range(batch_size): # Select within the batch the elements at indice b image = images[b] @@ -67,17 +67,15 @@ def compute_map_on_batch(images, target_bbox, target_class, m_outputs: dict, co # Target t_bbox, t_class = target_bbox[b], target_class[b] - if not RAGGED: - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) - t_bbox = bbox.xcycwh_to_yx_min_yx_max(t_bbox) - t_class = tf.slice(t_class, [1, 0], [size, -1]) - t_class = tf.squeeze(t_class, axis=-1) + #t_class = tf.slice(t_class, [1, 0], [size, -1]) + t_bbox = bbox.xcycwh_to_yx_min_yx_max(t_bbox) + t_class = tf.squeeze(t_class, axis=-1) # Inference ops predicted_bbox, predicted_labels, predicted_scores = get_model_inference(elem_m_outputs, config.background_class, bbox_format="yxyx") pred_mask = None - + + # Fake masks (durty adapted code) pred_mask = np.zeros((138, 138, len(predicted_bbox))) target_mask = np.zeros((138, 138, len(t_bbox))) WandbSender.compute_map( @@ -89,18 +87,30 @@ def compute_map_on_batch(images, target_bbox, target_class, m_outputs: dict, co -def train_log(images, t_bbox, t_class, m_outputs: dict, config, step, class_name=[], prefix="train/"): - # Every 1000 steps, log some progress of the training +def train_log(data, m_outputs: dict, config, step, class_name=[], prefix="train/"): + # Every x steps, log some progress of the training # (Images with bbox and images logs) if step % 100 == 0: - tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=step, prefix=prefix) + tf_send_batch_log_to_wandb(data["images"], data["target_bbox"], data["target_class"], m_outputs, config, config.batch_size, class_name=class_name, step=step, prefix=prefix) -def valid_log(images, t_bbox, t_class, m_outputs: dict, config, step, global_step, class_name=[], evaluation_step=200, prefix="train/"): +def valid_log(data: dict, m_outputs: dict, config, batch_size, step, global_step, class_name=[], evaluation_step=200, prefix="train/"): # Set the number of class WandbSender.init_ap_data(nb_class=len(class_name)) - map_list = compute_map_on_batch(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, send=(step+1==evaluation_step), prefix="val/") + + # Compute AP + map_list = compute_map_on_batch( + images=data["images"], + target_bbox=data["target_bbox"], + target_class=data["target_class"], + m_outputs=m_outputs, + config=config, + class_name=class_name, + step=global_step, + send=(step+1==evaluation_step), + prefix="val/" + ) if step == 0: - tf_send_batch_log_to_wandb(images, t_bbox, t_class, m_outputs, config, class_name=class_name, step=global_step, prefix="val/") + tf_send_batch_log_to_wandb(data["images"], data["target_bbox"], data["target_class"], m_outputs, config, batch_size, class_name=class_name, step=global_step, prefix="val/") diff --git a/detr_tf/logger/wandb_logging.py b/detr_tf/logger/wandb_logging.py index b3b0d3a9..b2f10b30 100644 --- a/detr_tf/logger/wandb_logging.py +++ b/detr_tf/logger/wandb_logging.py @@ -118,7 +118,7 @@ def compute_map(p_bbox: np.array, p_labels: np.array, p_scores: np.array, t_bbox except Exception as e: print("compute_map error. 
e=", e) - #raise e + raise e return np.array([0.0, 0.0], np.float64) return np.array([0.0, 0.0], np.float64) diff --git a/detr_tf/loss/compute_map.py b/detr_tf/loss/compute_map.py index fd7412c5..6b5027ec 100644 --- a/detr_tf/loss/compute_map.py +++ b/detr_tf/loss/compute_map.py @@ -182,13 +182,6 @@ def print_maps(all_maps): def cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, gt_classes, t_mask, ap_data, iou_thresholds): - #print("p_bbox", p_bbox.shape) - #print("p_labels", p_labels.shape) - #print("p_scores", p_scores.shape) - #print("p_mask", p_mask.shape) - #print("t_bbox", t_bbox.shape) - #print("gt_classes", gt_classes) - #print("t_mask", t_mask.shape) num_crowd = 0 @@ -220,8 +213,7 @@ def cal_map(p_bbox, p_labels, p_scores, p_mask, t_bbox, gt_classes, t_mask, ap_d lambda i,j: crowd_mask_iou_cache[i,j].item(), lambda i: mask_scores[i], mask_indices) ] - #print("run", list(classes), list(gt_classes)) - #print(classes + gt_classes) + for _class in set(list(classes) + list(gt_classes)): ap_per_iou = [] num_gt_for_class = sum([1 for x in gt_classes if x == _class]) diff --git a/detr_tf/loss/hungarian_matching.py b/detr_tf/loss/hungarian_matching.py index 35658307..301b4fc2 100644 --- a/detr_tf/loss/hungarian_matching.py +++ b/detr_tf/loss/hungarian_matching.py @@ -162,11 +162,13 @@ def loss_boxes(outputs, targets, indices, num_boxes): def hungarian_matching(t_bbox, t_class, p_bbox, p_class, fcost_class=1, fcost_bbox=5, fcost_giou=2, slice_preds=True) -> tuple: - if slice_preds: - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) - t_class = tf.slice(t_class, [1, 0], [size, -1]) - t_class = tf.squeeze(t_class, axis=-1) + t_class = tf.squeeze(t_class, axis=-1) + _filter = tf.squeeze(tf.where(t_class != -1), axis=-1) + #print("t_class", t_class.shape) + t_class = tf.gather(t_class, _filter) + #print('t_class', t_class.shape) + t_bbox = tf.gather(t_bbox, _filter) + #print('t_bbox', t_bbox.shape) # Convert frpm [xc, yc, w, h] to [xmin, ymin, xmax, ymax] p_bbox_xy = bbox.xcycwh_to_xy_min_xy_max(p_bbox) diff --git a/detr_tf/loss/loss.py b/detr_tf/loss/loss.py index 2d87e488..1cd4eee9 100644 --- a/detr_tf/loss/loss.py +++ b/detr_tf/loss/loss.py @@ -19,13 +19,13 @@ def get_total_losss(losses): return total_loss -def get_losses(m_outputs, t_bbox, t_class, config): - losses = get_detr_losses(m_outputs, t_bbox, t_class, config) +def get_losses(m_outputs, t_bbox, t_class, config, batch_size): + losses = get_detr_losses(m_outputs, t_bbox, t_class, config, batch_size) # Get auxiliary loss for each auxiliary output if "aux" in m_outputs: for a, aux_m_outputs in enumerate(m_outputs["aux"]): - aux_losses = get_detr_losses(aux_m_outputs, t_bbox, t_class, config, suffix="_{}".format(a)) + aux_losses = get_detr_losses(aux_m_outputs, t_bbox, t_class, config, batch_size, suffix="_{}".format(a)) losses.update(aux_losses) # Compute the total loss @@ -95,7 +95,7 @@ def loss_boxes(p_bbox, p_class, t_bbox, t_class, t_indices, p_indices, t_selecto return loss_giou, l1_loss -def get_detr_losses(m_outputs, target_bbox, target_label, config, suffix=""): +def get_detr_losses(m_outputs, target_bbox, target_label, config, batch_size, suffix=""): predicted_bbox = m_outputs["pred_boxes"] predicted_label = m_outputs["pred_logits"] @@ -112,9 +112,10 @@ def get_detr_losses(m_outputs, target_bbox, target_label, config, suffix=""): t_offset = 0 p_offset = 0 - for b in range(predicted_bbox.shape[0]): + for b in range(batch_size): p_bbox, p_class, t_bbox, t_class = predicted_bbox[b], 
predicted_label[b], target_bbox[b], target_label[b] + t_indices, p_indices, t_selector, p_selector, t_bbox, t_class = hungarian_matching(t_bbox, t_class, p_bbox, p_class, slice_preds=True) t_indices = t_indices + tf.cast(t_offset, tf.int64) diff --git a/detr_tf/networks/detr.py b/detr_tf/networks/detr.py index 7f2a202c..4a52529b 100644 --- a/detr_tf/networks/detr.py +++ b/detr_tf/networks/detr.py @@ -100,7 +100,7 @@ def add_heads_nlayers(config, detr, nb_class): tf.keras.layers.Dense(256, activation="relu"), tf.keras.layers.Dense(4, activation="sigmoid"), ], name="pos_layer") - config.add_nlayers([cls_layer, pos_layer]) + config.add_heads([cls_layer, pos_layer]) transformer_output = detr(image_input) cls_preds = cls_layer(transformer_output) @@ -139,6 +139,7 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba load_weights(detr, weights) image_input = tf.keras.Input((None, None, 3)) + image_mask = tf.keras.Input((None, None, 1)) # Backbone if not tf_backbone: @@ -169,18 +170,21 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba x = backbone(image_input) - masks = tf.zeros((tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]), tf.bool) + # Resize the mask to the same size of the backbone outptu + masks = tf.image.resize(image_mask, (tf.shape(x)[1], tf.shape(x)[2]), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + masks = tf.cast(masks, tf.int32) + pos_encoding = position_embedding_sine(masks) hs = transformer(input_proj(x), masks, query_embed(None), pos_encoding)[0] - detr = tf.keras.Model(image_input, hs, name="detr") + detr = tf.keras.Model([image_input, image_mask], hs, name="detr") if include_top is False and nb_class is None: return detr elif include_top is False and nb_class is not None: return add_heads_nlayers(config, detr, nb_class) - transformer_output = detr(image_input) + transformer_output = detr((image_input, masks)) outputs_class = class_embed(transformer_output) box_ftmps = activation(bbox_embed_linear1(transformer_output)) @@ -201,5 +205,5 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba "pred_boxes": pred_boxes }) - return tf.keras.Model(image_input, output, name="detr_finetuning") + return tf.keras.Model([image_input, image_mask], output, name="detr_finetuning") diff --git a/detr_tf/networks/position_embeddings.py b/detr_tf/networks/position_embeddings.py index c567c380..bd9acb38 100644 --- a/detr_tf/networks/position_embeddings.py +++ b/detr_tf/networks/position_embeddings.py @@ -1,12 +1,10 @@ import numpy as np import tensorflow as tf - -class PositionEmbeddingSine(tf.keras.Model): - - +class PositionEmbeddingSine(tf.keras.layers.Layer): + # These are the default parameters used in the original project def __init__(self, num_pos_features=64, temperature=10000, - normalize=False, scale=None, eps=1e-6, **kwargs): + normalize=False, scale=None, eps=1e-6, center=False, **kwargs): super().__init__(**kwargs) self.num_pos_features = num_pos_features @@ -18,33 +16,63 @@ def __init__(self, num_pos_features=64, temperature=10000, scale = 2 * np.pi self.scale = scale self.eps = eps + self.center = center + def call(self, mask): - not_mask = tf.cast(~mask, tf.float32) - y_embed = tf.math.cumsum(not_mask, axis=1) - x_embed = tf.math.cumsum(not_mask, axis=2) + + not_mask = tf.cast(mask == 0, tf.float32) + + y_embed_mask = tf.cumsum(not_mask, axis=1) + x_embed_mask = tf.cumsum(not_mask, axis=2) + y_embed_mask = tf.squeeze(y_embed_mask, axis=-1) + x_embed_mask = tf.squeeze(x_embed_mask, 
axis=-1) + + #print("y_embed_mask", y_embed_mask.shape) + #print("x_embed_mask", x_embed_mask.shape) + + x = tf.range(tf.shape(mask)[2]) + 1 + y = tf.range(tf.shape(mask)[1]) + 1 + x_embed, y_embed = tf.meshgrid(x, y) + + x_embed = tf.expand_dims(x_embed, axis=0) + y_embed = tf.expand_dims(y_embed, axis=0) + + x_embed = tf.tile(x_embed, [tf.shape(mask)[0], 1, 1,]) + y_embed = tf.tile(y_embed, [tf.shape(mask)[0], 1, 1,]) + x_embed = tf.cast(x_embed, tf.float32) + y_embed = tf.cast(y_embed, tf.float32) + + #print('x_embed', x_embed.shape) + #print("y_embed", y_embed.shape) if self.normalize: - y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + if self.center: + y_embed = y_embed-0.5 + x_embed = x_embed-0.5 + y_embed = y_embed / (y_embed_mask[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed_mask[:, :, -1:] + self.eps) * self.scale dim_t = tf.range(self.num_pos_features, dtype=tf.float32) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_features) - pos_x = x_embed[..., tf.newaxis] / dim_t - pos_y = y_embed[..., tf.newaxis] / dim_t - + x_embed = tf.expand_dims(x_embed, axis=-1) + y_embed = tf.expand_dims(y_embed, axis=-1) + + pos_x = x_embed / dim_t + pos_y = y_embed / dim_t + pos_x = tf.stack([tf.math.sin(pos_x[..., 0::2]), tf.math.cos(pos_x[..., 1::2])], axis=4) pos_y = tf.stack([tf.math.sin(pos_y[..., 0::2]), tf.math.cos(pos_y[..., 1::2])], axis=4) - shape = [tf.shape(pos_x)[i] for i in range(3)] + [-1] pos_x = tf.reshape(pos_x, shape) pos_y = tf.reshape(pos_y, shape) pos_emb = tf.concat([pos_y, pos_x], axis=3) + return pos_emb diff --git a/detr_tf/networks/transformer.py b/detr_tf/networks/transformer.py index 08f34448..60c402a5 100644 --- a/detr_tf/networks/transformer.py +++ b/detr_tf/networks/transformer.py @@ -269,19 +269,6 @@ def build(self, input_shapes): - - #self.in_proj_weight = tf.Variable( - # tf.zeros((in_dim, self.model_dim), dtype=tf.float32), name='in_proj_kernel') - #self.in_proj_bias = tf.Variable(tf.zeros((in_dim,), dtype=tf.float32), - # name='in_proj_bias') - - #self.out_proj_weight = tf.Variable( - # tf.zeros((self.model_dim, self.model_dim), dtype=tf.float32), name='out_proj_kernel') - #self.out_proj_bias = tf.Variable( - # tf.zeros((self.model_dim,), dtype=tf.float32), name='out_proj_bias') - - - def call(self, inputs, attn_mask=None, key_padding_mask=None, need_weights=True, training=False): @@ -319,8 +306,10 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None, if attn_mask is not None: attn_output_weights += attn_mask - """ + if key_padding_mask is not None: + key_padding_mask = tf.cast(key_padding_mask, tf.bool) + attn_output_weights = tf.reshape(attn_output_weights, [batch_size, self.num_heads, target_len, source_len]) @@ -328,13 +317,13 @@ def call(self, inputs, attn_mask=None, key_padding_mask=None, key_padding_mask = tf.expand_dims(key_padding_mask, 2) key_padding_mask = tf.tile(key_padding_mask, [1, self.num_heads, target_len, 1]) + #print("before attn_output_weights", attn_output_weights.shape) attn_output_weights = tf.where(key_padding_mask, tf.zeros_like(attn_output_weights) + float('-inf'), attn_output_weights) attn_output_weights = tf.reshape(attn_output_weights, [batch_size * self.num_heads, target_len, source_len]) - """ attn_output_weights = tf.nn.softmax(attn_output_weights, axis=-1) diff --git a/detr_tf/optimizers.py b/detr_tf/optimizers.py index ae1de914..c7f0a310 100644 --- a/detr_tf/optimizers.py +++ b/detr_tf/optimizers.py 
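The key_padding_mask block re-enabled above follows the usual attention-masking pattern: logits at padded key positions are pushed to -inf so the softmax gives them (near-)zero weight. A self-contained sketch of that pattern, with illustrative names and shapes rather than the layer's actual API:

    import tensorflow as tf

    def mask_padded_keys(attn_logits, key_padding_mask):
        # attn_logits: (batch, target_len, source_len) raw attention scores
        # key_padding_mask: (batch, source_len), True where the key position is padding
        mask = tf.expand_dims(key_padding_mask, axis=1)               # (batch, 1, source_len)
        mask = tf.broadcast_to(mask, tf.shape(attn_logits))
        neg_inf = tf.fill(tf.shape(attn_logits), float("-inf"))
        masked_logits = tf.where(mask, neg_inf, attn_logits)
        return tf.nn.softmax(masked_logits, axis=-1)                  # padded keys get ~0 weight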
@@ -1,4 +1,5 @@ import tensorflow as tf +import tensorflow_addons as tfa def disable_batchnorm_training(model): for l in model.layers: @@ -7,20 +8,6 @@ def disable_batchnorm_training(model): elif isinstance(l, tf.keras.layers.BatchNormalization): l.trainable = False -def get_transformers_trainable_variables(model, exclude=[]): - transformers_variables = [] - - # Transformers variables - transformers_variables = model.get_layer("detr").get_layer("transformer").trainable_variables - - for layer in model.layers[2:]: - if layer.name not in exclude: - transformers_variables += layer.trainable_variables - else: - pass - - return transformers_variables - def get_backbone_trainable_variables(model): backbone_variables = [] @@ -36,11 +23,26 @@ def get_backbone_trainable_variables(model): return backbone_variables -def get_nlayers_trainables_variables(model, nlayers_names): - nlayers_variables = [] - for nlayer_name in nlayers_names: - nlayers_variables += model.get_layer(nlayer_name).trainable_variables - return nlayers_variables +def get_transformers_trainable_variables(model, exclude=[]): + transformers_variables = [] + + # Transformers variables + transformers_variables = model.get_layer("detr").get_layer("transformer").trainable_variables + + for layer in model.layers[2:]: + if layer.name not in exclude: + transformers_variables += layer.trainable_variables + else: + pass + + return transformers_variables + + +def get_heads_trainables_variables(model, heads_names): + heads_variables = [] + for nlayer_name in heads_names: + heads_variables += model.get_layer(nlayer_name).trainable_variables + return heads_variables def get_trainable_variables(model, config): @@ -49,19 +51,14 @@ def get_trainable_variables(model, config): backbone_variables = [] transformers_variables = [] - nlayers_variables = [] - + heads_variables = [] - # Retrieve the gradient ofr each trainable variables - #if config.train_backbone: + # The gradient will be retrieve for each trainable variable backbone_variables = get_backbone_trainable_variables(model) - #if config.train_transformers: - transformers_variables = get_transformers_trainable_variables(model, exclude=config.nlayers) - #if config.train_nlayers: - nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers) + transformers_variables = get_transformers_trainable_variables(model, exclude=config.heads) + heads_variables = get_heads_trainables_variables(model, config.heads) - - return backbone_variables, transformers_variables, nlayers_variables + return backbone_variables, transformers_variables, heads_variables def setup_optimizers(model, config): @@ -76,59 +73,79 @@ def get_transformers_learning_rate(): return config.transformers_lr @tf.function - def get_nlayers_learning_rate(): - return config.nlayers_lr + def get_heads_learning_rate(): + return config.heads_lr + + @tf.function + def get_backbone_wd(): + return config.backbone_lr*config.backbone_wd + + @tf.function + def get_transformers_wd(): + return config.transformers_lr*config.transformers_wd + + @tf.function + def get_heads_wd(): + return config.heads_lr*config.heads_wd + + # Disable batch norm on the backbone disable_batchnorm_training(model) # Optimizers - backbone_optimizer = tf.keras.optimizers.Adam(learning_rate=get_backbone_learning_rate, clipnorm=config.gradient_norm_clipping) - transformers_optimizer = tf.keras.optimizers.Adam(learning_rate=get_transformers_learning_rate, clipnorm=config.gradient_norm_clipping) - nlayers_optimizer = 
tf.keras.optimizers.Adam(learning_rate=get_nlayers_learning_rate, clipnorm=config.gradient_norm_clipping) + backbone_optimizer = tfa.optimizers.AdamW( + learning_rate=get_backbone_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_backbone_wd + ) + transformers_optimizer = tfa.optimizers.AdamW( + learning_rate=get_transformers_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_transformers_wd + ) + heads_optimizer = tfa.optimizers.AdamW( + learning_rate=get_heads_learning_rate, clipnorm=config.gradient_norm_clipping, weight_decay=get_heads_wd + ) # Set trainable variables - backbone_variables, transformers_variables, nlayers_variables = [], [], [] + backbone_variables, transformers_variables, heads_variables = [], [], [] backbone_variables = get_backbone_trainable_variables(model) - transformers_variables = get_transformers_trainable_variables(model, exclude=config.nlayers) - nlayers_variables = get_nlayers_trainables_variables(model, config.nlayers) + transformers_variables = get_transformers_trainable_variables(model, exclude=config.heads) + heads_variables = get_heads_trainables_variables(model, config.heads) return { "backbone_optimizer": backbone_optimizer, "transformers_optimizer": transformers_optimizer, - "nlayers_optimizer": nlayers_optimizer, + "heads_optimizer": heads_optimizer, "backbone_variables": backbone_variables, "transformers_variables": transformers_variables, - "nlayers_variables": nlayers_variables, + "heads_variables": heads_variables, } def gather_gradient(model, optimizers, total_loss, tape, config, log): - backbone_variables, transformers_variables, nlayers_variables = get_trainable_variables(model, config) - trainables_variables = backbone_variables + transformers_variables + nlayers_variables + backbone_variables, transformers_variables, heads_variables = get_trainable_variables(model, config) + trainables_variables = backbone_variables + transformers_variables + heads_variables gradients = tape.gradient(total_loss, trainables_variables) # Retrieve the gradients from the tap backbone_gradients = gradients[:len(optimizers["backbone_variables"])] transformers_gradients = gradients[len(optimizers["backbone_variables"]):len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"])] - nlayers_gradients = gradients[len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"]):] + heads_gradients = gradients[len(optimizers["backbone_variables"])+len(optimizers["transformers_variables"]):] gradient_steps = {} gradient_steps["backbone"] = {"gradients": backbone_gradients} gradient_steps["transformers"] = {"gradients": transformers_gradients} - gradient_steps["nlayers"] = {"gradients": nlayers_gradients} + gradient_steps["heads"] = {"gradients": heads_gradients} log.update({"backbone_lr": optimizers["backbone_optimizer"]._serialize_hyperparameter("learning_rate")}) log.update({"transformers_lr": optimizers["transformers_optimizer"]._serialize_hyperparameter("learning_rate")}) - log.update({"nlayers_lr": optimizers["nlayers_optimizer"]._serialize_hyperparameter("learning_rate")}) + log.update({"heads_lr": optimizers["heads_optimizer"]._serialize_hyperparameter("learning_rate")}) return gradient_steps diff --git a/detr_tf/training.py b/detr_tf/training.py index edd23b22..f530403f 100644 --- a/detr_tf/training.py +++ b/detr_tf/training.py @@ -1,13 +1,38 @@ import tensorflow as tf +import matplotlib.pyplot as plt + from .optimizers import gather_gradient, aggregate_grad_and_apply from 
.logger.training_logging import valid_log, train_log from .loss.loss import get_losses import time import wandb + +def handle_data(data): + """ Create (TODO) a mask from the given ragged images. The mask will be use in the encoder/decode + attention. + """ + n_data = {} + for key in data: + n_data[key] = data[key] + + padding_mask = tf.ones_like(data["images"]) + # The following will add 0 on all padded part + padding_mask = padding_mask.to_tensor()[:,:,:,:1] + # Set one instead of zero on all paded part + padding_mask = tf.abs(padding_mask - 1) + + n_data["images"] = n_data["images"].to_tensor() + n_data["mask"] = padding_mask + + return n_data + + @tf.function -def run_train_step(model, images, t_bbox, t_class, optimizers, config): +def run_train_step(model, data, optimizers, config): + + n_data = handle_data(data) if config.target_batch is not None: gradient_aggregate = int(config.target_batch // config.batch_size) @@ -15,8 +40,8 @@ def run_train_step(model, images, t_bbox, t_class, optimizers, config): gradient_aggregate = 1 with tf.GradientTape() as tape: - m_outputs = model(images, training=True) - total_loss, log = get_losses(m_outputs, t_bbox, t_class, config) + m_outputs = model((n_data["images"], n_data["mask"]), training=True) + total_loss, log = get_losses(m_outputs, t_bbox=n_data["target_bbox"], t_class=n_data["target_class"], config=config, batch_size=config.batch_size) total_loss = total_loss / gradient_aggregate # Compute gradient for each part of the network @@ -26,9 +51,12 @@ def run_train_step(model, images, t_bbox, t_class, optimizers, config): @tf.function -def run_val_step(model, images, t_bbox, t_class, config): - m_outputs = model(images, training=False) - total_loss, log = get_losses(m_outputs, t_bbox, t_class, config) +def run_val_step(model, data, config, batch_size): + + n_data = handle_data(data) + + m_outputs = model((n_data["images"], n_data["mask"]), training=False) + total_loss, log = get_losses(m_outputs, t_bbox=n_data["target_bbox"], t_class=n_data["target_class"], config=config, batch_size=batch_size) return m_outputs, total_loss, log @@ -40,14 +68,16 @@ def fit(model, train_dt, optimizers, config, epoch_nb, class_names): if config.target_batch is not None: gradient_aggregate = int(config.target_batch // config.batch_size) t = None - for epoch_step , (images, t_bbox, t_class) in enumerate(train_dt): + for epoch_step , data in enumerate(train_dt): + + data = train_dt.itertuple2dict(data) # Run the prediction and retrieve the gradient step for each part of the network - m_outputs, total_loss, log, gradient_steps = run_train_step(model, images, t_bbox, t_class, optimizers, config) + m_outputs, total_loss, log, gradient_steps = run_train_step(model, data, optimizers, config) # Load the predictions if config.log: - train_log(images, t_bbox, t_class, m_outputs, config, config.global_step, class_names, prefix="train/") + train_log(handle_data(data), m_outputs, config, config.global_step, class_names, prefix="train/") # Aggregate and apply the gradient for name in gradient_steps: @@ -65,16 +95,18 @@ def fit(model, train_dt, optimizers, config, epoch_nb, class_names): config.global_step += 1 -def eval(model, valid_dt, config, class_name, evaluation_step=200): +def eval(model, valid_dt, config, class_name, evaluation_step=200, batch_size=None): """ Evaluate the model on the validation set """ + batch_size = config.batch_size if batch_size is None else batch_size t = None - for val_step, (images, t_bbox, t_class) in enumerate(valid_dt): + for val_step, data in 
enumerate(valid_dt): + data = valid_dt.itertuple2dict(data) # Run prediction - m_outputs, total_loss, log = run_val_step(model, images, t_bbox, t_class, config) + m_outputs, total_loss, log = run_val_step(model, data, config, batch_size) # Log the predictions if config.log: - valid_log(images, t_bbox, t_class, m_outputs, config, val_step, config.global_step, class_name, evaluation_step=evaluation_step, prefix="train/") + valid_log(handle_data(data), m_outputs, config, batch_size, val_step, config.global_step, class_name, evaluation_step=evaluation_step, prefix="train/") # Log the metrics if config.log and val_step == 0: wandb.log({f"val/{k}":log[k] for k in log}, step=config.global_step) diff --git a/detr_tf/training_config.py b/detr_tf/training_config.py index 4f884d7b..c3b254d2 100644 --- a/detr_tf/training_config.py +++ b/detr_tf/training_config.py @@ -19,18 +19,24 @@ def training_config_parser(): # What to train parser.add_argument("--train_backbone", action='store_true', required=False, default=False, help="Train backbone") parser.add_argument("--train_transformers", action='store_true', required=False, default=False, help="Train transformers") - parser.add_argument("--train_nlayers", action='store_true', required=False, default=False, help="Train new layers") + parser.add_argument("--train_heads", action='store_true', required=False, default=False, help="Train the model heads (For finetuning)") # How to train + parser.add_argument("--image_size", default=None, required=False, type=str) parser.add_argument("--finetuning", default=False, required=False, action='store_true', help="Load the model weight before to train") parser.add_argument("--batch_size", type=int, required=False, default=1, help="Batch size to use to train the model") parser.add_argument("--gradient_norm_clipping", type=float, required=False, default=0.1, help="Gradient norm clipping") parser.add_argument("--target_batch", type=int, required=False, default=None, help="When running on a single GPU, aggretate the gradient before to apply.") # Learning rate - parser.add_argument("--backbone_lr", type=bool, required=False, default=1e-5, help="Train backbone") - parser.add_argument("--transformers_lr", type=bool, required=False, default=1e-4, help="Train transformers") - parser.add_argument("--nlayers_lr", type=bool, required=False, default=1e-4, help="Train new layers") + parser.add_argument("--backbone_lr", type=float, required=False, default=1e-5, help="Backbone learning rate") + parser.add_argument("--transformers_lr", type=float, required=False, default=1e-4, help="Transformer learning rate") + parser.add_argument("--heads_lr", type=float, required=False, default=1e-4, help="Model heads learning rate") + + # Weight decay + parser.add_argument("--backbone_wd", type=float, required=False, default=1e-4, help="Backbone weight decay") + parser.add_argument("--transformers_wd", type=float, required=False, default=1e-4, help="Transformer weight decay") + parser.add_argument("--heads_wd", type=float, required=False, default=1e-4, help="Model heads weight decay") # Logging parser.add_argument("--log", required=False, action="store_true", default=False, help="Log into wandb") @@ -46,12 +52,16 @@ def __init__(self): self.data_dir, self.img_dir, self.ann_dir, self.ann_file = None, None, None, None self.data = DataConfig(data_dir=None, img_dir=None, ann_file=None, ann_dir=None) self.background_class = 0 - self.image_size = 376, 672 + + #self.image_size = 376, 672 + # If image size is None, then multi scale training will be used as 
+ # described in the paper. + self.image_size = None # What to train self.train_backbone = False self.train_transformers = False - self.train_nlayers = False + self.train_heads = False # How to train self.finetuning = False @@ -65,8 +75,17 @@ def __init__(self): # keeping the same graph self.backbone_lr = tf.Variable(1e-5) self.transformers_lr = tf.Variable(1e-4) - self.nlayers_lr = tf.Variable(1e-4) - self.nlayers = [] + self.heads_lr = tf.Variable(1e-4) + + # Weidht decay + # Set as tf.Variable so that the variable can be update during the training while + # keeping the same graph + self.backbone_wd = tf.Variable(1e-4) + self.transformers_wd = tf.Variable(1e-4) + self.heads_wd = tf.Variable(1e-4) + + # Heads layer list + self.heads = [] # Training progress self.global_step = 0 @@ -76,10 +95,10 @@ def __init__(self): self.normalized_method = "torch_resnet" - def add_nlayers(self, layers): + def add_heads(self, layers): """ Set the new layers to train on the training config """ - self.nlayers = [l.name for l in layers] + self.heads = [l.name for l in layers] def update_from_args(self, args): @@ -92,9 +111,11 @@ def update_from_args(self, args): else: setattr(self, key, args[key]) - # Set the config on the data class - + if self.image_size is not None: + img_size = self.image_size.split(",") + self.image_size = (int(img_size[0]), int(img_size[1])) + # Set the config on the data class self.data = DataConfig( data_dir=self.data_dir, img_dir=self.img_dir, diff --git a/eval.py b/eval.py index 98c6189b..8d5718ff 100644 --- a/eval.py +++ b/eval.py @@ -14,6 +14,11 @@ from detr_tf.bbox import xcycwh_to_xy_min_xy_max, xcycwh_to_yx_min_yx_max from detr_tf.inference import numpy_bbox_to_image from detr_tf.training_config import TrainingConfig, training_config_parser +from detr_tf.training import handle_data + +tf.random.set_seed(40) +np.random.seed(40) + def build_model(config): @@ -27,6 +32,12 @@ def build_model(config): return detr +#@tf.function +def run_model(data, model): + n_data = handle_data(data) + return model((n_data["images"], n_data["mask"])) + + def eval_model(model, config, class_names, valid_dt): """ Run evaluation """ @@ -38,24 +49,32 @@ def eval_model(model, config, class_names, valid_dt): } it = 0 - for images, target_bbox, target_class in valid_dt: + for data in valid_dt: + data = valid_dt.itertuple2dict(data) + # Forward pass - m_outputs = model(images) + m_outputs = run_model(data, model) + # Run predictions p_bbox, p_labels, p_scores = get_model_inference(m_outputs, config.background_class, bbox_format="yxyx") + # Remove padding - t_bbox, t_class = target_bbox[0], target_class[0] - size = tf.cast(t_bbox[0][0], tf.int32) - t_bbox = tf.slice(t_bbox, [1, 0], [size, 4]) + t_bbox, t_class = data["target_bbox"][0], data["target_class"][0] + t_bbox = xcycwh_to_yx_min_yx_max(t_bbox) - t_class = tf.slice(t_class, [1, 0], [size, -1]) t_class = tf.squeeze(t_class, axis=-1) + + # Filter undesired target + _filter = tf.squeeze(tf.where(t_class != -1), axis=-1) + t_class = tf.gather(t_class, _filter) + t_bbox = tf.gather(t_bbox, _filter) + # Compute map cal_map(p_bbox, p_labels, p_scores, np.zeros((138, 138, len(p_bbox))), np.array(t_bbox), np.array(t_class), np.zeros((138, 138, len(t_bbox))), ap_data, iou_thresholds) print(f"Computing map.....{it}", end="\r") it += 1 - #if it > 10: - # break + if it > 10: + break # Compute the mAp over all thresholds calc_map(ap_data, iou_thresholds, class_names, print_result=True) @@ -73,7 +92,7 @@ def eval_model(model, config, class_names, valid_dt): # 
Load the model with the new layers to finetune detr = build_model(config) - valid_dt, class_names = load_coco_dataset(config, 1, augmentation=None) + valid_dt, class_names = load_coco_dataset(config, 2, augmentation=False) # Run training eval_model(detr, config, class_names, valid_dt) diff --git a/finetune_coco.py b/finetune_coco.py index e1945918..44626e70 100644 --- a/finetune_coco.py +++ b/finetune_coco.py @@ -30,7 +30,7 @@ def build_model(config): """ Build the model with the pretrained weights. In this example we do not add new layers since the pretrained model is already trained on coco. - See examples/finetuning_voc.py to add new layers. + See the finetuning_voc.py script see an example on how to change the number of class on the last layer. """ # Load the pretrained model detr = get_detr_model(config, include_top=True, weights="detr") @@ -44,8 +44,10 @@ def run_finetuning(config): detr = build_model(config) # Load the training and validation dataset - train_dt, coco_class_names = load_coco_dataset("train", config.batch_size, config, augmentation=True) - valid_dt, _ = load_coco_dataset("val", 1, config, augmentation=False) + train_dt, coco_class_names = load_coco_dataset( + config, config.batch_size, augmentation=True, img_dir="val2017", ann_file="annotations/instances_val2017.json") + valid_dt, _ = load_coco_dataset( + config, 1, augmentation=False, img_dir="val2017", ann_file="annotations/instances_val2017.json") # Train/finetune the transformers only config.train_backbone = False @@ -56,7 +58,7 @@ def run_finetuning(config): # Run the training for 5 epochs for epoch_nb in range(100): - training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) + training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=100, batch_size=1) training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names) diff --git a/finetune_voc.py b/finetune_voc.py index 9bf56674..cccb1f99 100644 --- a/finetune_voc.py +++ b/finetune_voc.py @@ -31,14 +31,16 @@ ] def build_model(config): - """ Build the model with the pretrained weights - and add new layers to finetune + """ Build the model with the pretrained weights. + We set include_top to False to not include the last layer of the transformer. + Then, `nb_class` is used to automaticly replace the lat layers by new layers with the + appropriate number of target class. """ # Input image_input = tf.keras.Input((None, None, 3)) - - # Load the pretrained model - detr = get_detr_model(config, include_top=False, weights="detr", num_decoder_layers=6, num_encoder_layers=6) + # Load the pretrained model and replace the laster layers for this new task. 
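+    # Only the heads are new here: get_detr_model keeps the pretrained backbone and
+    # transformer weights and, because nb_class is given with include_top=False, adds
+    # freshly initialised classification/box head layers sized for the VOC classes plus background.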
+ detr = get_detr_model(config, include_top=False, weights="detr", nb_class=len(VOC_CLASS_NAME)+1) + return detr # Setup the new layers cls_layer = tf.keras.layers.Dense(len(VOC_CLASS_NAME) + 1, name="cls_layer") @@ -68,7 +70,7 @@ def run_finetuning(config): detr = build_model(config) # Load the training and validation dataset (for the purpose of this example we're gonna load the training - # as the validation, but in practise you should have different folder loader for the training and the validation) + # as the validation, but in practise you should have different folder and loader for the training and the validation) train_dt, class_names = load_voc_dataset(config, config.batch_size, augmentation=True) valid_dt, _ = load_voc_dataset(config, 1, augmentation=False) @@ -95,8 +97,8 @@ def run_finetuning(config): config.transformers_lr.assign(1e-4) config.nlayers_lr.assign(1e-3) - training.eval(detr, valid_dt, config, class_names, evaluation_step=200) training.fit(detr, train_dt, optimzers, config, epoch_nb, class_names) + training.eval(detr, valid_dt, config, class_names, evaluation_step=200) if __name__ == "__main__": diff --git a/train_coco.py b/train_coco.py index c4ac578c..b0d97818 100644 --- a/train_coco.py +++ b/train_coco.py @@ -48,9 +48,9 @@ def run_finetuning(config): # Load the training and validation dataset train_dt, coco_class_names = load_coco_dataset( - config, config.batch_size, augmentation=True, img_dir="train2017", ann_fil="annotations/instances_train2017.json") + config, config.batch_size, augmentation=True, img_dir="val2017", ann_file="annotations/instances_val2017.json") valid_dt, _ = load_coco_dataset( - config, 1, augmentation=False, img_dir="val2017", ann_fil="annotations/instances_val2017.json") + config, 1, augmentation=False, img_dir="val2017", ann_file="annotations/instances_val2017.json") # Train the backbone and the transformers # Check the training_config file for the other hyperparameters @@ -62,8 +62,8 @@ def run_finetuning(config): # Run the training for 100 epochs for epoch_nb in range(100): - training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) training.fit(detr, train_dt, optimzers, config, epoch_nb, coco_class_names) + #training.eval(detr, valid_dt, config, coco_class_names, evaluation_step=200) if __name__ == "__main__": From 7e27b7affbac48d3abbb2bf03776e77c4c787f86 Mon Sep 17 00:00:00 2001 From: Thibault Date: Tue, 15 Jun 2021 09:29:28 +0000 Subject: [PATCH 2/5] Deformable --- detr_tf/data/coco.py | 12 ++++++++---- eval.py | 6 +++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/detr_tf/data/coco.py b/detr_tf/data/coco.py index d58d38ce..938b776e 100644 --- a/detr_tf/data/coco.py +++ b/detr_tf/data/coco.py @@ -112,7 +112,7 @@ def iter_tuple_to_dict(data): } -def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None): +def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None, shuffle=True): """ Load a coco dataset Parameters @@ -132,6 +132,8 @@ def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_ img_dir: str Path to the img_dir relative to the data_dir If None, will be equal to config.data.img_dir + shuffle : bool + Shuffle the dataset by default """ ann_dir = config.data.ann_dir if ann_dir is None else ann_dir if ann_dir is None: @@ -155,11 +157,13 @@ def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_ # Setup the data pipeline img_ids = coco.getImgIds() - - 
#shuffle(img_ids) + + if shuffle: + shuffle(img_ids) dataset = tf.data.Dataset.from_tensor_slices(img_ids) # Shuffle the dataset - #dataset = dataset.shuffle(1000) + if shuffle: + dataset = dataset.shuffle(1000) # Retrieve img and labels outputs_types=(tf.float32, tf.float32, tf.int64) diff --git a/eval.py b/eval.py index 8d5718ff..899372cf 100644 --- a/eval.py +++ b/eval.py @@ -73,8 +73,8 @@ def eval_model(model, config, class_names, valid_dt): cal_map(p_bbox, p_labels, p_scores, np.zeros((138, 138, len(p_bbox))), np.array(t_bbox), np.array(t_class), np.zeros((138, 138, len(t_bbox))), ap_data, iou_thresholds) print(f"Computing map.....{it}", end="\r") it += 1 - if it > 10: - break + #if it > 10: + # break # Compute the mAp over all thresholds calc_map(ap_data, iou_thresholds, class_names, print_result=True) @@ -92,7 +92,7 @@ def eval_model(model, config, class_names, valid_dt): # Load the model with the new layers to finetune detr = build_model(config) - valid_dt, class_names = load_coco_dataset(config, 2, augmentation=False) + valid_dt, class_names = load_coco_dataset(config, 1, augmentation=False, shuffle=False) # Run training eval_model(detr, config, class_names, valid_dt) From 785e26fcedaf3f8b02c63a5309867775a56a951e Mon Sep 17 00:00:00 2001 From: Thibault Date: Tue, 15 Jun 2021 11:30:42 +0200 Subject: [PATCH 3/5] Deformable DETR Inference --- detr_tf/custom_ops/ms_deform_attn/__init__.py | 25 + detr_tf/custom_ops/ms_deform_attn/build.sh | 7 + .../ms_deform_attn/ms_deform_attn.py | 187 +++ .../ms_deform_attn/ms_deform_im2col.cc | 226 +++ .../ms_deform_attn/ms_deform_im2col.cu.cc | 1353 +++++++++++++++++ .../ms_deform_attn/ms_deform_im2col.o | Bin 0 -> 223032 bytes detr_tf/custom_ops/ms_deform_attn/test.py | 206 +++ detr_tf/inference.py | 34 +- detr_tf/networks/custom_layers.py | 44 +- detr_tf/networks/deformable_detr.py | 343 +++++ detr_tf/networks/deformable_transformer.py | 583 +++++++ detr_tf/networks/detr.py | 2 +- detr_tf/networks/resnet_backbone.py | 10 +- detr_tf/networks/weights.py | 10 + webcam_inference.py | 29 +- 15 files changed, 3030 insertions(+), 29 deletions(-) create mode 100644 detr_tf/custom_ops/ms_deform_attn/__init__.py create mode 100755 detr_tf/custom_ops/ms_deform_attn/build.sh create mode 100644 detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py create mode 100644 detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc create mode 100644 detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc create mode 100644 detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o create mode 100644 detr_tf/custom_ops/ms_deform_attn/test.py create mode 100644 detr_tf/networks/deformable_detr.py create mode 100644 detr_tf/networks/deformable_transformer.py diff --git a/detr_tf/custom_ops/ms_deform_attn/__init__.py b/detr_tf/custom_ops/ms_deform_attn/__init__.py new file mode 100644 index 00000000..f27f6062 --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/__init__.py @@ -0,0 +1,25 @@ +import os.path +import tensorflow as tf + +if tf.test.is_built_with_cuda(): + _cuda_op_module = tf.load_op_library(os.path.join( + tf.compat.v1.resource_loader.get_data_files_path(), 'ms_deform_im2col.so')) + ms_deform_im2col = _cuda_op_module.ms_deform_im2col + + + @tf.RegisterGradient("MsDeformIm2col") + def _zero_out_grad(op, grad): + grad_value, grad_sampling_loc, grad_attn_weight = _cuda_op_module.ms_deform_im2col_grad( + op.inputs[0], + op.inputs[1], + op.inputs[2], + op.inputs[3], + op.inputs[4], + grad + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight + + +else: 
+ raise ValueError("Trying to load cuda ms_deform_im2col without cuda support") \ No newline at end of file diff --git a/detr_tf/custom_ops/ms_deform_attn/build.sh b/detr_tf/custom_ops/ms_deform_attn/build.sh new file mode 100755 index 00000000..17a3057b --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/build.sh @@ -0,0 +1,7 @@ +# With tf env activated +TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) +TF_LFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) + +nvcc -std=c++11 -c -o ms_deform_im2col.o ms_deform_im2col.cu.cc ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr + +g++ -std=c++11 -shared -o ms_deform_im2col.so ms_deform_im2col.cc ms_deform_im2col.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]} diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py b/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py new file mode 100644 index 00000000..29b28aa7 --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/ms_deform_attn.py @@ -0,0 +1,187 @@ +import tensorflow as tf + +#Python failback of MSDeformAttnFunction + +def MSDeformAttnFunction(values, sampling_locations, attention_weights): + + # for debug and test only, + # need to use cuda version instead + """ + :param values level, (N, H, W, num_heads, head_dim) + :param sampling_locations level, (N, Len_q, num_heads, num_sampling_points, 2) + :param attention_weights N, Len_q, num_heads, num_level, num_sampling_points + """ + + sampling_value_list = [] + for lid_, (value, sl) in enumerate(zip(values, sampling_locations)): + N, h_l, w_l, num_heads, head_dim = tf.unstack(tf.shape(value)) + # N*num_heads, h, w, c + value = tf.reshape(tf.transpose(value, [0, 3, 1, 2, 4]), [N*num_heads, h_l, w_l, head_dim]) + + # N, Len_q, num_heads, num_sampling_points, 2 + sl = 2 * sl - 1 #between (-1, 1) + N, Len_q, num_heads, num_sampling_points, _ = tf.unstack(tf.shape(sl)) + + # N*num_heads, Len_q, num_sampling_points, 2 + sampling_grid_l_ = tf.reshape(tf.transpose(sl, [0, 2, 1, 3, 4]), [N*num_heads, Len_q, num_sampling_points, 2]) + + #N*num_heads, Len_q, num_sampling_points, c + if True: + sampled_values = bilinear_sampler(value, sampling_grid_l_) + else: + sampled_values = nearest_sampler(value, sampling_grid_l_) + + sampling_value_list.append(sampled_values) + + # N*num_heads, Len_q, num_level, num_sampling_points, c + sampling_value = tf.stack(sampling_value_list, axis=2) + # N, num_heads, Len_q, num_level, num_sampling_points, c + sampling_value = tf.reshape(sampling_value, (N, num_heads, Len_q, len(values), num_sampling_points, head_dim)) + # N, Len_q, num_heads, num_level, num_sampling_points, c + sampling_value = tf.transpose(sampling_value, [0, 2, 1, 3, 4, 5]) + # (N, Len_q, num_heads, num_level, num_sampling_points, 1) + attention_weights = tf.expand_dims(attention_weights, -1) + # N, Len_q, num_heads, num_level, num_sampling_points, c + output = attention_weights * sampling_value + # N, Len_q, num_heads, -1, head_dim + output = tf.reshape(output, (N, Len_q, num_heads, -1, head_dim)) + # N, Len_q, num_heads, c + output = tf.reduce_sum(output, axis=3) + + output = tf.reshape(output, (N, Len_q, num_heads*head_dim)) + + return output + + +def within_bounds(x, lower, upper): + lower_tensor = tf.greater_equal(x, lower) + upper_tensor = tf.less_equal(x, upper) + return tf.logical_and(lower_tensor, upper_tensor) + +def bilinear_sampler(image, coords): + ''' Value sampler using tf.gather_nd + Args: + image: 
tensor with shape (bs, h, w, c) + coords: coordinates tensor with shape (bs, ... , 2), xy-indexing between 0, 1 + + Returns: + sampled tensor with shape (bs, ... , c) + ''' + + #Correspond to padding="zeros" (optimistic : discard only out of bound bilinear coefficient, not the full value) + + with tf.name_scope("bilinear_sampler"): + _, h, w, _ = tf.unstack(tf.shape(image)) + + + gx, gy = tf.unstack(coords, axis=-1) + + # rescale x and y to [0, W-1/H-1] + gx = (gx+1.0)/2.0 * tf.cast(w-1, tf.float32) + gy = (gy+1.0)/2.0 * tf.cast(h-1, tf.float32) + + gx0 = tf.floor(gx) + gx1 = gx0 + 1.0 + gy0 = tf.floor(gy) + gy1 = gy0 + 1.0 + + mx0 = within_bounds(gx0, 0, tf.cast(w, tf.float32)-1) + mx1 = within_bounds(gx1, 0, tf.cast(w, tf.float32)-1) + my0 = within_bounds(gy0, 0, tf.cast(h, tf.float32)-1) + my1 = within_bounds(gy1, 0, tf.cast(h, tf.float32)-1) + + c00 = tf.expand_dims((gy1 - gy)*(gx1 - gx), axis=-1) + c01 = tf.expand_dims((gy1 - gy)*(gx - gx0), axis=-1) + c10 = tf.expand_dims((gy - gy0)*(gx1 - gx), axis=-1) + c11 = tf.expand_dims((gy - gy0)*(gx - gx0), axis=-1) + + #clip for CPU (out_of_bound-error), optionnal on GPU (as corresponding m.. while be zeroed) + gx0 = tf.clip_by_value(gx0, 0, tf.cast(w, tf.float32)-1) + gx1 = tf.clip_by_value(gx1, 0, tf.cast(w, tf.float32)-1) + gy0 = tf.clip_by_value(gy0, 0, tf.cast(h, tf.float32)-1) + gy1 = tf.clip_by_value(gy1, 0, tf.cast(h, tf.float32)-1) + + g00 = tf.stack([gy0, gx0], axis=-1) + g01 = tf.stack([gy0, gx1], axis=-1) + g10 = tf.stack([gy1, gx0], axis=-1) + g11 = tf.stack([gy1, gx1], axis=-1) + + m00 = tf.cast(tf.expand_dims(tf.logical_and(my0, mx0), axis=-1), tf.float32) + m01 = tf.cast(tf.expand_dims(tf.logical_and(my0, mx1), axis=-1), tf.float32) + m10 = tf.cast(tf.expand_dims(tf.logical_and(my1, mx0), axis=-1), tf.float32) + m11 = tf.cast(tf.expand_dims(tf.logical_and(my1, mx1), axis=-1), tf.float32) + + x00 = tf.gather_nd(image, tf.cast(g00, dtype=tf.int32), batch_dims=1) + x01 = tf.gather_nd(image, tf.cast(g01, dtype=tf.int32), batch_dims=1) + x10 = tf.gather_nd(image, tf.cast(g10, dtype=tf.int32), batch_dims=1) + x11 = tf.gather_nd(image, tf.cast(g11, dtype=tf.int32), batch_dims=1) + + output = c00 * x00 * m00 \ + + c01 * x01 * m01 \ + + c10 * x10 * m10 \ + + c11 * x11 * m11 + + return output + + +def nearest_sampler(image, coords): + with tf.name_scope("nearest_sampler"): + _, h, w, _ = tf.unstack(tf.shape(image)) + + gx, gy = tf.unstack(coords, axis=-1) + + # rescale x and y to [0, W-1/H-1] + gx = (gx+1.0)/2.0 * tf.cast(w-1, tf.float32) + gy = (gy+1.0)/2.0 * tf.cast(h-1, tf.float32) + + gx0 = tf.round(gx) + gy0 = tf.round(gy) + + g00 = tf.stack([gy0, gx0], axis=-1) + + return tf.gather_nd(image, tf.cast(g00, dtype=tf.int32), batch_dims=1) + + + +if __name__ == "__main__": + import torch + import torch.nn.functional as F + + import numpy as np + + for i in range(1000): + + test_size = 100 + + grid_size = test_size + feature_len = 1 + batch_size = test_size + + grid_sampling_size = test_size + + values = np.random.rand(batch_size, grid_size, grid_size, feature_len) + + t_values = np.transpose(values, (0, 3, 1, 2) ) + + coords = np.random.rand(batch_size, grid_sampling_size, grid_sampling_size, 2) * 2 - 1 + coords = coords * 1.1 + + values = values.astype(np.float32) + coords = coords.astype(np.float32) + t_values = t_values.astype(np.float32) + + tf_result = bilinear_sampler(values, coords) + tf_result = tf_result.numpy() + + torch_result = F.grid_sample(torch.from_numpy(t_values), torch.from_numpy(coords), + mode='bilinear', 
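+        # align_corners=True matches the (w-1)/(h-1) rescaling done in
+        # bilinear_sampler above, and padding_mode='zeros' matches its masking
+        # of out-of-bound corner samples, so the two samplers should agree.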
padding_mode='zeros', align_corners=True) + + + torch_result = torch_result.view(batch_size, grid_sampling_size, grid_sampling_size, feature_len).numpy() + + diff = np.abs(tf_result - torch_result) + + print("diff", np.amax(diff), np.unravel_index(diff.argmax(), diff.shape)) + + if np.amax(diff) > 1e-3: + break diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc new file mode 100644 index 00000000..afcef79a --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cc @@ -0,0 +1,226 @@ +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; // NOLINT(build/namespaces) + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +/* +:param values level, (N, H, W, num_heads, head_dim) +:param sampling_locations level, (N, Len_q, num_heads, num_sampling_points, 2) +:param attention_weights N, Len_q, num_heads, num_level, num_sampling_points +*/ + + +REGISTER_OP("MsDeformIm2col") + .Input("value: float") // (N, Len_in, n_heads, d_model//n_heads) + .Input("spatial_shapes: int32") // (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + .Input("level_start_index: int32") // (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + .Input("sampling_loc: float") // (N, Len_q, n_heads, n_levels, n_points, 2) + .Input("attn_weight: float") // (N, Len_q, num_heads, n_level, num_sampling_points) + .Attr("im2col_step:int = 64") + .Output("col: float") // N, Len_q, num_heads*head_dim + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + auto batch_size = c->Dim(c->input(0), 0); + auto num_heads = c->Dim(c->input(0), 2); + auto channels = c->Dim(c->input(0), 3); + auto num_query = c->Dim(c->input(3), 1); + auto outChannels = c->MakeDim(round(c->Value(num_heads)*c->Value(channels))); + c->set_output(0, c->MakeShape({batch_size, num_query, outChannels})); + + return Status::OK(); + }); + + + +REGISTER_OP("MsDeformIm2colGrad") + .Input("value: float") // (N, Len_in, n_heads, d_model//n_heads) + .Input("spatial_shapes: int32") // (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + .Input("level_start_index: int32") // (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + .Input("sampling_loc: float") // (N, Len_q, n_heads, n_levels, n_points, 2) + .Input("attn_weight: float") // (N, Len_q, num_heads, n_level, num_sampling_points) + .Input("grad_output: float") // N, Len_q, num_heads*head_dim + .Attr("im2col_step:int = 64") + .Output("grad_value: float") + .Output("grad_sampling_loc: float") + .Output("grad_attn_weight: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(3)); + c->set_output(2, c->input(4)); + return Status::OK(); + }); + + +void ms_deformable_col2im_cuda(const GPUDevice& d, + const float* grad_col, + const float* value, + const int * spatial_shapes, + const int * level_start_index, + const float * sampling_loc, + const float * attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* grad_value, + float* grad_sampling_loc, + 
float* grad_attn_weight); + +void ms_deformable_im2col_cuda(const GPUDevice& d, + const float* value, + const int* spatial_shapes, + const int* level_start_index, + const float* sampling_loc, + const float* attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* col); + + + + + +template +class MsDeformIm2colOp : public OpKernel { + public: + explicit MsDeformIm2colOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("im2col_step", &im2col_step_)); + OP_REQUIRES(context, im2col_step_ >= 0, + errors::InvalidArgument("Need im2col_step_ >= 0, got ", + im2col_step_)); + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& value = context->input(0); + const Tensor& spatial_shapes = context->input(1); + const Tensor& level_start_index = context->input(2); + const Tensor& sampling_loc = context->input(3); + const Tensor& attn_weight = context->input(4); + + const int batch_size = value.dim_size(0); + const int spatial_size = value.dim_size(1); + const int num_heads = value.dim_size(2); + const int channels = value.dim_size(3); + const int num_levels = spatial_shapes.dim_size(0); + const int num_query = sampling_loc.dim_size(1); + const int num_point = sampling_loc.dim_size(4); + + const int im2col_step = std::min(batch_size, im2col_step_); + + Tensor* output_tensor = nullptr; + + TensorShape output_tensor_shape = TensorShape({batch_size, num_query, num_heads*channels}); + + OP_REQUIRES_OK(context, context->allocate_output(0, output_tensor_shape, &output_tensor)); + auto col = output_tensor->flat(); + + + // Call the cuda kernel launcher + ms_deformable_im2col_cuda(context->eigen_gpu_device(), + value.flat().data(), + spatial_shapes.flat().data(), + level_start_index.flat().data(), + sampling_loc.flat().data(), + attn_weight.flat().data(), + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + col.data()); + } + private: + int im2col_step_; +}; + + + +template +class MsDeformIm2colGradOp : public OpKernel { + public: + explicit MsDeformIm2colGradOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("im2col_step", &im2col_step_)); + OP_REQUIRES(context, im2col_step_ >= 0, + errors::InvalidArgument("Need im2col_step_ >= 0, got ", + im2col_step_)); + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& value = context->input(0); + const Tensor& spatial_shapes = context->input(1); + const Tensor& level_start_index = context->input(2); + const Tensor& sampling_loc = context->input(3); + const Tensor& attn_weight = context->input(4); + const Tensor& grad_output = context->input(5); + + const int batch_size = value.dim_size(0); + const int spatial_size = value.dim_size(1); + const int num_heads = value.dim_size(2); + const int channels = value.dim_size(3); + const int num_levels = spatial_shapes.dim_size(0); + const int num_query = sampling_loc.dim_size(1); + const int num_point = sampling_loc.dim_size(4); + + Tensor* output_tensor_value = nullptr; + Tensor* output_tensor_sampling_loc = nullptr; + Tensor* output_tensor_attn_weight = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, value.shape(), &output_tensor_value)); + OP_REQUIRES_OK(context, context->allocate_output(1, sampling_loc.shape(), &output_tensor_sampling_loc)); + 
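+    // The three gradient outputs mirror the shapes of value, sampling_loc and
+    // attn_weight (inputs 0, 3 and 4), as declared in the MsDeformIm2colGrad
+    // shape function above; spatial_shapes and level_start_index are index
+    // tensors and receive no gradient.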
OP_REQUIRES_OK(context, context->allocate_output(2, attn_weight.shape(), &output_tensor_attn_weight)); + + auto output_flat = output_tensor_value->flat(); + + + // Call the cuda kernel launcher + ms_deformable_col2im_cuda(context->eigen_gpu_device(), + grad_output.flat().data(), + value.flat().data(), + spatial_shapes.flat().data(), + level_start_index.flat().data(), + sampling_loc.flat().data(), + attn_weight.flat().data(), + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + output_tensor_value->template flat().data(), + output_tensor_sampling_loc->template flat().data(), + output_tensor_attn_weight->template flat().data()); + + } + private: + int im2col_step_; +}; + + + + + + +REGISTER_KERNEL_BUILDER(Name("MsDeformIm2col").Device(DEVICE_GPU), MsDeformIm2colOp); + +REGISTER_KERNEL_BUILDER(Name("MsDeformIm2colGrad").Device(DEVICE_GPU), MsDeformIm2colGradOp); \ No newline at end of file diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc new file mode 100644 index 00000000..6410a01b --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.cu.cc @@ -0,0 +1,1353 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +//#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" +#include "tensorflow/core/util/gpu_cuda_alias.h" +//#include "tensorflow/core/util/gpu_device_functions.h" + +typedef Eigen::GpuDevice GPUDevice; + + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 
1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = (width-1) * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = (height-1) * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const 
int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * 
(spatial_h-1); //- 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); //- 0.5; + + // if (h_im >= 0 && w_im >= 0 && h_im <= spatial_h-1 && w_im <= spatial_w-1) + // { + // const int h_r = round(h_im); + // const int w_r = round(w_im); + // const int w_stride = num_heads * channels; + // const int h_stride = spatial_w * w_stride; + // const int h_ptr_offset = h_r * h_stride; + // const int w_ptr_offset = w_r * w_stride; + // const int base_ptr = m_col * channels + c_col; + + // const int ptr1 = h_ptr_offset + w_ptr_offset + base_ptr; + + // col += data_value_ptr[ptr1]* weight; + // } + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, 
channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + 
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, 
m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + 
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + 
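+          // Each thread clears its shared-memory slots, accumulates its
+          // per-point gradient contribution, and the block then tree-reduces
+          // cache_grad_sampling_loc / cache_grad_attn_weight; in this
+          // multi-block variant, thread 0 finally adds the block total into
+          // the global gradient buffers with atomicAdd instead of writing it
+          // directly.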
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int *data_spatial_shapes, + const int *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const 
scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * (spatial_h-1); // - 0.5; + const scalar_t w_im = loc_w * (spatial_w-1); // - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +void ms_deformable_im2col_cuda(const GPUDevice &d, + const float* data_value, + const int* data_spatial_shapes, + const int* data_level_start_index, + const float* data_sampling_loc, + const float* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + + +void ms_deformable_col2im_cuda(const GPUDevice &d, + const float* grad_col, + const float* data_value, + const int * data_spatial_shapes, + const int * data_level_start_index, + const float * data_sampling_loc, + const float * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + float* grad_value, + float* grad_sampling_loc, + float* grad_attn_weight) +{ + + cudaMemset(grad_value, 0, batch_size*num_heads*channels*spatial_size*sizeof(float)); + cudaMemset(grad_sampling_loc, 0, batch_size*num_query*num_heads*num_levels*num_point*2*sizeof(float)); + cudaMemset(grad_attn_weight, 0, batch_size*num_query*num_heads*num_levels*num_point*sizeof(float)); + + + + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + 
data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + 
grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} diff --git a/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o b/detr_tf/custom_ops/ms_deform_attn/ms_deform_im2col.o new file mode 100644 index 0000000000000000000000000000000000000000..e025baa6841a875f13886df5771c1f8d040ea8ba GIT binary patch literal 223032 zcmeEv3w#vUdGFbo)x#KzMJBckBYCs}BfE~h5@2LsCt0l|YwbcS8584lB_u#Lq<#ni zvC|vu)nknng;WreI`H8_-Ud5$;^emJB5=WN+G6LnaeD_U3#BI4zO8tqIH^_ZZ0%>w=1?%LwaIdkTm^F8N$-#PQ0XX=*SqqtnEuzoJ_N#U|H3F%YIohri5 z<$m#qc$n+XN1mR!$x0VmX}OjDnw2iH(#1%ZSl4$TU20udA-&VOz6Z?>{|jcAbgaM}oau4zCOLxfc*#gS~YhRKyO|jZy8sI=}&fOy}}57_)dy=bPhwE6?=q|M4g_WVPbyL;WpO`p>=N++kncSg>#F2kU~1?q<|& z&Y^*cH*5_1dOvu$x6VvK?X2(BnGf$jF>BWMVn8V6_Me#1J#r&P#+i|87KjlV(s=jQ z_ix-%(YA7uI{a8#1`eff9h--=Vsj9b?>(`o^R|O?yGx$|H$=M9l5%09H;}BV1m=L zJar>L2Z1X&B?;UW!S46V6CpiO5u5d0X1VThS5ItK&yOL=KcIpi;!gy(li&$P7~Opy zcZ1vO{#UT~5SIaO-CIqud3|s@Nwf&|aa;E7VDACqO<6bC`{%*lf8@u4#3WWt6#j$@ zxuF|_+pnR9en5h4H<~brKJ@)U@CPHx^{>TpZWF;fk34x}a6{~{g4#|$-14Eix2d7` zB(?F}opo=|+8(23`#wf@`s&_BQ`WOCVDoJnpd5<@*Q}mxRDTNCk{_{aIivW@5*lkdzTL`>#ciXS?_WhP!LT#m)I+~v9HcE>)Ulz_|yAVme2b3aBtmYWd{${ z9Tsk)WgS|#7_6@dV>Z7J9u+^Je}2d{th;@62RUzLLF?^xN8)wkjntWa%g2Jfe*vX& zq;JdH_xApMHA{4`_f7CEBOwyoQoXEidH=G$x)%WcV6YFc^=;YLR~I|C<+;AC{s;TE zR^iWwzjFNP_$$Jn=Rtd@-3R^LQg7Xly0`vFY%cGs`;mQ*hQ?yY-Ey_}^6l|WPRPEd z52QY;egH^67(NfuJ%n_itpRBm>8fCF`!oDFQh_i#i{X1wylZPK3Au zkA$U~aBxgD{Z?%vVvuXUE^Bi?(|W*ao=21Pz-vU(&aFO5X8lH;kGaF?v~MLw2#z_} zS2v;Tu>I7R{|da(SLa)B4pJ0}&596)0ZJktfsnzq;P4|sS1jl{65Kzc+}nGwZwm=U z0)yE^>01dNJ>2_uAAz@)_d>`L7;9GK?`8@yIddI^JmsHzl}L2%RifFsE#bbcK`_?j z3=7#PM z{Ihy)&Ots5E}7MH8kau{a8*1#t7m&2RV8NZc`zTh7UI^G!tv8HE)nA_K*LuGs6PZk zFNrj)xPp}W2e8T|0jd)XUnxrEUn-!!4H>&qK;419FA1fNLdRSwpw^-9O9Ir@R}iS3 z(9D+vs5{Z{m7-J(ZwSpmw3*D+QExsen3wdB&9j>OS;+Nhoy^ z^Q`WzzE92i-4p+@Cf^TXV3!19npvij@roH4Tee&Z7~i|G|2?h$y&L-9yRQGeYy02x z^uKpa|9dm~-^=TN&)xr?tDk5M=N4PMvwB`Y<^oplQ?`M1xzow5?Covizf~Qw943=H zY%PDyKfpE%+8!?Hort@1VS(GuUU|!MMRvYFPWbuLW%O|GtA0*)RdKT2hh$lAq@0uO zIu)`5B~paTM1!5&!jT71NXYKvE3(0N$yOhDJ>1P@u7q2%r%7+g-Y309BMSCTdW%Lm z#VsEBq)%-HUcNBN@KzR0-qcEcit_R)%6(ImS4~mwpS*k*VLb7cLugDCgg6vgkNc2z zTs`Vsku7caQx8M@gTE)q$R5`evTeL+09*bQu_+le}j4rzWlqtRC zGr_(;`$yGf+h=-$ecP$I+P(;N7wp?b%Y&Z|!g&$* z$sRg!9TrI7Nb*Y@d1s^Ip6JExYD^4dNCTBq0c8rxO zgUdMi#^Uo7t(hZE(uJ_O3(wCfLU 
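For reference, the forward computation implemented by the MsDeformIm2col op (and by the pure-TensorFlow MSDeformAttnFunction fallback) reduces, for a single feature level, to a bilinear lookup at each sampling location followed by an attention-weighted sum per head, flattened over heads. The plain NumPy sketch below restates that forward pass; the toy sizes and variable names are illustrative only and are not part of the patch.

    import numpy as np

    # Toy sizes, chosen only for illustration.
    N, H, W, num_heads, head_dim = 1, 8, 8, 2, 4
    Len_q, num_points = 3, 2

    value = np.random.rand(N, H, W, num_heads, head_dim).astype(np.float32)
    # Normalized (x, y) sampling locations in [0, 1], as consumed by the kernel.
    loc = np.random.rand(N, Len_q, num_heads, num_points, 2).astype(np.float32)
    attn = np.random.rand(N, Len_q, num_heads, num_points).astype(np.float32)
    attn /= attn.sum(-1, keepdims=True)

    out = np.zeros((N, Len_q, num_heads * head_dim), np.float32)
    for n in range(N):
        for q in range(Len_q):
            for m in range(num_heads):
                acc = np.zeros(head_dim, np.float32)
                for p in range(num_points):
                    # Same scaling as the CUDA kernel: w_im = loc_w*(W-1), h_im = loc_h*(H-1).
                    x = loc[n, q, m, p, 0] * (W - 1)
                    y = loc[n, q, m, p, 1] * (H - 1)
                    x0, y0 = int(np.floor(x)), int(np.floor(y))
                    x1, y1 = min(x0 + 1, W - 1), min(y0 + 1, H - 1)
                    lx, ly = x - x0, y - y0
                    # Bilinear weights, matching w1..w4 in ms_deform_attn_im2col_bilinear.
                    v = ((1 - ly) * (1 - lx) * value[n, y0, x0, m]
                         + (1 - ly) * lx * value[n, y0, x1, m]
                         + ly * (1 - lx) * value[n, y1, x0, m]
                         + ly * lx * value[n, y1, x1, m])
                    acc += attn[n, q, m, p] * v
                out[n, q, m * head_dim:(m + 1) * head_dim] = acc

    print(out.shape)  # (1, 3, 8) == (N, Len_q, num_heads * head_dim)

With several levels, the same weighted sum simply runs over the per-level lookups as well, which is what the level stacking in MSDeformAttnFunction and the level_start_index bookkeeping in the CUDA kernel express.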
zAbv68^ZGCJ#2+f$9;_vPIb92Cd5&Vanm|To`t0LDwB-*d_YKrPT`O9C`uHLTw zUJJy38Sy8#iuUFG%u9jzOAx;T{<8j#H!mH(&g3n1((QNLHMT+*!QV~9mdi9_=Bs}+ zph`%&u{P}_SrJne534(6__9l5>{rfG5SN{y-@2h`>Cw{Rfep4X+12N-Y8i?QM zi9f;<|KEZ5{TIfo|EGcYH+tfic;Yvh>ppz#|6t7cb;jSM5Ea{W&73U{dE!f7MX!Ik zzmG3|_O1*oszc~;;Z*jc(Zw$oG^u#ao#Ghyq z^3{Ju%=lxCzpws7J@J3&iT_q0e*Sgw>ia{GSHm7hG@azqkI=`hb5C{H1Tm#j2*r0CT;Iul}QB#;*_fpAG*msZfbI z9e+3c_fZR`!e4UZqCcN$8c3M-pBB^qQTVr0@$2F5wilO!{=Cfi`__MxXZ_2iXwKhl ze{Unc>u0+-jfkJ?0H&_~Zu%blm05H@{55xej_-!Qs($nk{OiqfE-rt^J~7J^^>+Z% z+fo15;ZS3Em3cXTSO3JiQ@--C$DKB#f}`!iS~HPzO4{FS{$D!w5s%U}*8k^c%BNJ( a^*_`svU}>3jm%9MZs`H;1P&H!|Nj9EZ5KuW literal 0 HcmV?d00001 diff --git a/detr_tf/custom_ops/ms_deform_attn/test.py b/detr_tf/custom_ops/ms_deform_attn/test.py new file mode 100644 index 00000000..80eadbc1 --- /dev/null +++ b/detr_tf/custom_ops/ms_deform_attn/test.py @@ -0,0 +1,206 @@ + +"""Cuda op Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import numpy as np +import tensorflow as tf + +import math + +from surroundnet.custom_ops.ms_deform_attn import ms_deform_im2col + +from surroundnet.custom_ops.ms_deform_attn.ms_deform_attn import MSDeformAttnFunction + + + +import torch +import torch.nn.functional as F + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + + #print(sampling_locations[:, :, :, 1, :, :]) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) + + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, + mode='bilinear', padding_mode='zeros', align_corners=True) + # N_*M_, Lq_, P_, D_ + #tmp = sampling_value_l_.permute(0, 2, 3, 1) + #print( f"sampling_value_l_{lid_}", tmp ) + # tmp = value_l_.permute(0, 2, 3, 1) + # print("value", tmp) + + # tmp = sampling_value_l_.permute(0, 2, 3, 1) + # print("sampled_values", tmp) + + # print("sampling_grid_l_", sampling_grid_l_) + # exit() + + sampling_value_list.append(sampling_value_l_) + #exit() + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) + # (N_*M_, D_, Lq_, L_*P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + #print(N_, M_, D_, Lq_, L_, P_) + + tmp = torch.stack(sampling_value_list, dim=-2).flatten(-2) + + # print("MSDeformAttnFunction_attention_weights", attention_weights.shape, attention_weights.view(N_, M_, 1, Lq_, L_, P_)) + # print("MSDeformAttnFunction_sampl", tmp.shape, tmp.view(N_, M_, D_, Lq_, L_, P_)) + + #print("MSDeformAttnFunction_att*sampl", output.shape, output.view(N_, M_, D_, Lq_, L_, P_)[0, 2, :, 1, 3, 2]) + + # (N_*M_, D_, Lq_) -> (N_, M_*D_, Lq_) + output = output.sum(-1).view(N_, M_*D_, Lq_) + # (N_, Lq_, M_*D_) + return output.transpose(1, 2).contiguous() + + + + + + + + + + +N = 1 
+ +n_heads = 8 +d_model = 256 +size = np.array( (128, 128) ) +Len_q = 13 + + +# n_heads = 4 +# d_model = 4 +# size = np.array( (16, 16) ) + +# Len_q = 3 + +n_levels = 4 +num_sampling_points = 4 + +values = list() +spatial_shapes = list() +level_start_index = [0] + +for i in range(n_levels): + value = tf.random.uniform( shape=(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads) ) * 0.01 + #value = tf.ones( shape=(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads) ) + values.append(value) + spatial_shapes.append(size) + level_start_index.append(size[0]*size[1] + level_start_index[-1]) + + size = size//2 + +flatten_attn_weight = tf.random.uniform( (N, Len_q, n_heads, n_levels, num_sampling_points) ) + 1e-5 +flatten_attn_weight /= tf.reduce_sum(tf.reduce_sum(flatten_attn_weight, axis=-1, keepdims=True), axis=-2, keepdims=True) + +#flatten_attn_weight = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points) ) + + +flatten_sampling_loc = tf.random.uniform( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), minval=-0.1, maxval=1.1, dtype=tf.float32 ) +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) *10 #(127+0.999) #* math.pi /10 +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) *0.5 #* math.pi /10 +#flatten_sampling_loc = tf.ones( (N, Len_q, n_heads, n_levels, num_sampling_points, 2), dtype=tf.float32 ) * math.pi /10 + + + +level_start_index = np.array( level_start_index, dtype=np.int32 ) + +spatial_shapes = np.array( spatial_shapes, dtype=np.int32 ) + + +with tf.GradientTape(persistent=True) as g: + g.watch(flatten_sampling_loc) + g.watch(values) + g.watch(flatten_attn_weight) + + sampling_loc = tf.unstack(flatten_sampling_loc, axis=3) #(N, Len_q, n_heads, num_sampling_points) + flatten_value = tf.concat( [tf.reshape(v, (N, -1, n_heads, d_model//n_heads) ) for v in values], axis=1) + + py_res = MSDeformAttnFunction(values, sampling_loc, flatten_attn_weight) + + res = ms_deform_im2col( + flatten_value, # (N, Len_in, n_heads, d_model//n_heads) + spatial_shapes, # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + level_start_index, # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + flatten_sampling_loc, # (N, Len_q, n_heads, n_levels, n_points, 2) + flatten_attn_weight # (N, Len_q, n_heads, n_level, n_points) + ) + +# Save tensors to npy for TensorRT plugin test +# np.save("./flatten_value.npy", flatten_value.numpy()) +# np.save("./spatial_shapes.npy", spatial_shapes) +# np.save("./level_start_index.npy", level_start_index) +# np.save("./flatten_sampling_loc.npy", flatten_sampling_loc.numpy()) +# np.save("./flatten_attn_weight.npy", flatten_attn_weight.numpy()) +# np.save("./ms_deform_im2col_out.npy", res.numpy()) + +def check_value(name, py_grad, cu_grad): + print(name, py_grad.shape) + print("\t min value :", tf.reduce_min(py_grad), tf.reduce_min(cu_grad)) + print("\t max value :", tf.reduce_max(py_grad), tf.reduce_max(cu_grad)) + print("\t mean value :", tf.reduce_mean(py_grad), tf.reduce_mean(cu_grad)) + print("\t std value :", tf.math.reduce_std(py_grad), tf.math.reduce_std(cu_grad)) + abs_err = tf.reduce_max(tf.abs(py_grad - cu_grad)) + #coord = tf.math.argmax(tf.abs(py_grad - cu_grad)) + print("\t max abs error :", abs_err) + abs_err = tf.reduce_mean(tf.abs(py_grad - cu_grad)) + print("\t mean abs error :", abs_err) + #rel_err = tf.reduce_max(tf.abs(py_grad - 
cu_grad)/(tf.math.sqrt( tf.abs(py_grad)*tf.abs(cu_grad) ) + 1e-3) ) + #print("\t max rel error :", rel_err) + +check_value("VALUE python / CUDA", py_res, res) +#print(py_res) +#print(res) + +pytorch_res = ms_deform_attn_core_pytorch( + torch.from_numpy(flatten_value.numpy()), + spatial_shapes, + torch.from_numpy(flatten_sampling_loc.numpy()), + torch.from_numpy(flatten_attn_weight.numpy()) ) + +check_value("VALUE pytorch / tensorflow", pytorch_res, res) + + +#print(pytorch_res) +check_value("GRAD Sampling Loc", g.gradient(py_res, flatten_sampling_loc), g.gradient(res, flatten_sampling_loc)) +check_value("GRAD Value", g.gradient(py_res, values[0]), g.gradient(res, values[0])) +check_value("GRAD Attention", g.gradient(py_res, flatten_attn_weight), g.gradient(res, flatten_attn_weight)) + + +#(N, int(size[0]), int(size[1]), n_heads, d_model//n_heads) +#py_gvalue = g.gradient(py_res, values[0]) +#cu_gvalue = g.gradient(res, values[0]) +#print("cu_gvalue", cu_gvalue) +#print("cu_gvalue", cu_gvalue) + +# args = [ +# flatten_value, # (N, Len_in, n_heads, d_model#n_heads) +# spatial_shapes, # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] +# level_start_index, # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] +# flatten_sampling_loc, # (N, Len_q, n_heads, n_levels, n_points, 2) +# flatten_attn_weight # (N, Len_q, n_heads, n_level, n_points) +# ] + +#CUDA +#numerical, theoric = tf.test.compute_gradient(ms_deform_im2col, args, delta=0.001) \ No newline at end of file diff --git a/detr_tf/inference.py b/detr_tf/inference.py index c82fa1c0..957233fd 100644 --- a/detr_tf/inference.py +++ b/detr_tf/inference.py @@ -65,24 +65,34 @@ def numpy_bbox_to_image(image, bbox_list, labels=None, scores=None, class_name=[ return image -def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center"): +def get_model_inference(m_outputs: dict, background_class, bbox_format="xy_center", threshold=None): + + #print('get model inference', [key for key in m_outputs]) + + # Detr or deformable + predicted_bbox = m_outputs["pred_boxes"][0] if "pred_boxes" in m_outputs else m_outputs["bbox_pred_boxes"][0] + predicted_labels = m_outputs["pred_logits"][0] if "pred_logits" in m_outputs else m_outputs["bbox_pred_logits"][0] + activation = "softmax" if "pred_boxes" in m_outputs else "sigmoid" + + if activation == "softmax": # Detr + softmax = tf.nn.softmax(predicted_labels) + predicted_scores = tf.reduce_max(softmax, axis=-1) + predicted_labels = tf.argmax(softmax, axis=-1) + indices = tf.where(predicted_labels != background_class) + indices = tf.squeeze(indices, axis=-1) + else: # Deformable Detr + sigmoid = tf.nn.sigmoid(predicted_labels) + predicted_scores = tf.reduce_max(sigmoid, axis=-1) + predicted_labels = tf.argmax(sigmoid, axis=-1) + threshold = 0.1 if threshold is None else threshold + indices = tf.where(predicted_scores > threshold) + indices = tf.squeeze(indices, axis=-1) - predicted_bbox = m_outputs["pred_boxes"][0] - predicted_labels = m_outputs["pred_logits"][0] - - softmax = tf.nn.softmax(predicted_labels) - predicted_scores = tf.reduce_max(softmax, axis=-1) - predicted_labels = tf.argmax(softmax, axis=-1) - - - indices = tf.where(predicted_labels != background_class) - indices = tf.squeeze(indices, axis=-1) predicted_scores = tf.gather(predicted_scores, indices) predicted_labels = tf.gather(predicted_labels, indices) predicted_bbox = tf.gather(predicted_bbox, indices) - if bbox_format == "xy_center": predicted_bbox 
= predicted_bbox elif bbox_format == "xyxy": diff --git a/detr_tf/networks/custom_layers.py b/detr_tf/networks/custom_layers.py index 1a0c524e..bfe0b967 100644 --- a/detr_tf/networks/custom_layers.py +++ b/detr_tf/networks/custom_layers.py @@ -28,23 +28,27 @@ def compute_output_shape(self, input_shape): return input_shape + class Linear(tf.keras.layers.Layer): ''' Use this custom layer instead of tf.keras.layers.Dense to allow loading converted PyTorch Dense weights that have shape (output_dim, input_dim) ''' - def __init__(self, output_dim, **kwargs): + def __init__(self, output_dim, kernel_initializer=tf.keras.initializers.GlorotUniform(), bias_initializer=tf.keras.initializers.GlorotUniform(), **kwargs): super().__init__(**kwargs) self.output_dim = output_dim + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + def build(self, input_shape): self.kernel = self.add_weight(name='kernel', shape=[self.output_dim, input_shape[-1]], - initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + initializer=self.kernel_initializer, trainable=True) self.bias = self.add_weight(name='bias', shape=[self.output_dim], - initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + initializer=self.bias_initializer, trainable=True) def call(self, x): return tf.matmul(x, self.kernel, transpose_b=True) + self.bias @@ -65,3 +69,37 @@ def build(self, input_shape): def call(self, x=None): return self.w + + +class ScaleLevelEmbedding(tf.keras.layers.Layer): + def __init__(self, num_level, embed_shape, **kwargs): + super().__init__(**kwargs) + self.embed_shape = embed_shape + self.num_level = num_level + + def build(self, input_shape): + self.w = self.add_weight(name='kernel', shape=(self.num_level, self.embed_shape), + initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0), trainable=True) + + super().build(input_shape) + + def call(self, x=None): + return self.w + + +class MLP(tf.keras.layers.Layer): + def __init__(self, hidden_dim, output_dim, kernel_initializer=tf.keras.initializers.GlorotUniform(), bias_initializer=tf.keras.initializers.GlorotUniform(), **kwargs): + super().__init__(**kwargs) + + self.layer_0 = Linear(hidden_dim, name='layer_0') + self.layer_1 = Linear(hidden_dim, name='layer_1') + self.layer_2 = Linear(output_dim, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, name='layer_2') + + + def call(self, x, training=False): + x = tf.nn.relu(self.layer_0(x)) + x = tf.nn.relu(self.layer_1(x)) + x = self.layer_2(x) + + return x + diff --git a/detr_tf/networks/deformable_detr.py b/detr_tf/networks/deformable_detr.py new file mode 100644 index 00000000..f9ae4108 --- /dev/null +++ b/detr_tf/networks/deformable_detr.py @@ -0,0 +1,343 @@ +import pickle +import tensorflow as tf +import numpy as np +import time +import cv2 +import matplotlib.pyplot as plt +import os +import math +import json +from pathlib import Path +import tensorflow_addons as tfa +import functools +import collections + +from detr_tf.networks.deformable_transformer import DeformableTransformer +from detr_tf.networks.transformer import MultiHeadAttention +from detr_tf.networks.resnet_backbone import ResNet50Backbone +from detr_tf.networks.custom_layers import Linear, FixedEmbedding, ScaleLevelEmbedding, MLP +from detr_tf.networks.position_embeddings import PositionEmbeddingSine +from detr_tf.networks.transformer import Transformer +from detr_tf.networks.weights import load_weights + +class DeformableDETR(tf.keras.Model): + def __init__(self, 
+ model_dim=256, + num_classes=91, + num_queries=300, + num_sampling_points=4, + backbone=None, + pos_encoder=None, + transformer=None, + num_encoder_layers=6, + num_decoder_layers=6, + return_intermediate_dec=True, + init_query_embedding=False, + batch_size=None, + use_mask_bn=False, + refine_bbox=True, + multiscale=True, + train_encoder=False, + **kwargs): + super().__init__(**kwargs) + self.num_queries = num_queries + + self.backbone = ResNet50Backbone(name='backbone') + + self.pos_encoder = pos_encoder or PositionEmbeddingSine( + num_pos_features=model_dim // 2, normalize=True, center=True) + + self.query_embed = FixedEmbedding((num_queries, model_dim*2), name='query_embed') + self.level_embed = ScaleLevelEmbedding(4, model_dim, name="level_embed", trainable=train_encoder) + + + self.multiscale = multiscale + + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed = list( + Linear(num_classes, + bias_initializer=tf.keras.initializers.Constant(bias_value), + name=f'class_embed_{i}') + for i in range(num_decoder_layers)) + + self.bbox_embed = list( + MLP(model_dim, + 4, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name=f'bbox_embed_{i}') + for i in range(num_decoder_layers)) + + # hack to force shared weight (different from pytorch cloning approach) + if not refine_bbox: + self.class_embed = [self.class_embed[0] for _ in range(num_decoder_layers)] + self.bbox_embed = [self.bbox_embed[0] for _ in range(num_decoder_layers)] + + self.transformer = transformer or DeformableTransformer( + query_embed_layer=self.query_embed, + level_embed=self.level_embed, + layer_position_embedding_sine=self.pos_encoder, + model_dim=model_dim, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + num_sampling_points=num_sampling_points, + return_intermediate_dec=return_intermediate_dec, + init_query_embedding=init_query_embedding, + class_embed=self.class_embed, + bbox_embed=self.bbox_embed, + refine_bbox=refine_bbox, + train_encoder=train_encoder, + name='transformer' + ) + self.model_dim = model_dim + + layer_norm = tf.keras.layers.BatchNormalization + + self.input_proj_0 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/0/0', trainable=train_encoder) + self.input_proj_gn_0 = layer_norm(name="input_proj_gn/0/1", trainable=train_encoder) + + if multiscale: + self.input_proj_1 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/1/0', trainable=train_encoder) + self.input_proj_2 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/2/0', trainable=train_encoder) + self.input_proj_3 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=3, strides=2, name='input_proj/3/0', trainable=train_encoder) + + self.input_proj_gn_1 = layer_norm(name="input_proj_gn/1/1", trainable=train_encoder) + self.input_proj_gn_2 = layer_norm(name="input_proj_gn/2/1", trainable=train_encoder) + self.input_proj_gn_3 = layer_norm(name="input_proj_gn/3/1", trainable=train_encoder) + + #self.activation = tf.keras.layers.ReLU() + + self.num_decoder_layers = num_decoder_layers + + + def call(self, inp, training=False, post_process=False): + x = inp + backbone_outputs = self.backbone(x, training=training) + x2, x1, x0, _ = backbone_outputs + + if self.multiscale: + src_proj_outputs = [self.input_proj_gn_0(self.input_proj_0(x0)), \ + self.input_proj_gn_1(self.input_proj_1(x1)), \ + self.input_proj_gn_2(self.input_proj_2(x2)), \ + 
self.input_proj_gn_3(tf.keras.layers.ZeroPadding2D(1)(self.input_proj_3(x2)))] + else: + src_proj_outputs = [self.input_proj_gn_0(self.input_proj_0(x2))] + + masks = list(tf.zeros([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1], tf.float32) for x in src_proj_outputs) + + decoder, encoder, outputs_coord = self.transformer(src_proj_outputs, + masks, + training=training) + + outputs_class = list(class_embed(x) for class_embed, x in zip(self.class_embed, tf.split(decoder, self.num_decoder_layers)) ) + + output = {'bbox_pred_logits': outputs_class[-1], + 'bbox_pred_boxes': outputs_coord[-1]} + + if post_process: + output = self.post_process(output) + + a = self.query_embed(None) + + return output + + +class DetrClassHead(tf.keras.layers.Layer): + + def __init__(self, detr, include_top, nb_class=None, refine_bbox=False, **kwargs): + """ + """ + super().__init__(name="detr_class_head", **kwargs) + self.include_top = include_top + if self.include_top: + if refine_bbox: + self.layer_class_embed = list(detr.get_layer(f'class_embed_{i}') for i in range(detr.num_decoder_layers)) + else: + #shared weights + self.layer_class_embed = list(detr.get_layer(f'class_embed_0') for _ in range(detr.num_decoder_layers) ) + else: + # Setup the new layers + if refine_bbox: + self.layer_class_embed = list(tf.keras.layers.Dense(nb_class, name=f"class_embed_{i}") for i in range(detr.num_decoder_layers) ) + else: + layer = tf.keras.layers.Dense(nb_class, name=f"class_embed_0") + self.layer_class_embed = list(layer for i in range(detr.num_decoder_layers) ) + + def call(self, decoder_state): + outputs = {} + + # Output class + outputs_class = [l(s) for l, s in zip(self.layer_class_embed, tf.unstack(decoder_state, axis=0))] + + outputs = {'bbox_pred_logits': outputs_class[-1]} + + outputs["bbox_aux"] = [] + for out_class in outputs_class: + outputs["bbox_aux"].append({ + "bbox_pred_logits": out_class + }) + + return outputs + + + def build(self, input_shape=None, **kwargs): + super().build(input_shape, **kwargs) + + + +def get_detr_core(detr, backbone, model_dim, tf_backbone=False, multiscale=True): + """ DETR Core is made of the backbone and the transformer part without the + heads + """ + + layer_transformer = detr.get_layer("transformer") + + #### Set ops + if not tf_backbone: + image_input = tf.keras.Input((None, None, 3)) + backbone_outputs = backbone(image_input) + x2, x1, x0, _ = backbone_outputs + else: + image_input = backbone.inputs + _ = backbone.get_layer("conv1_relu").output #/2 + _ = backbone.get_layer("conv2_block3_out").output #/4 + x0 = backbone.get_layer("conv3_block4_out").output #/8 + x1 = backbone.get_layer("conv4_block6_out").output #/16 + x2 = backbone.get_layer("conv5_block3_out").output #/32 + backbone_outputs = x2, x1, x0, _ + + if multiscale: + src_proj_outputs = list((None,None, None, None)) + for i, tensor in enumerate([x0, x1, x2, x2]): + input_proj_layer = detr.get_layer(f'input_proj/{i}/0') + input_proj_gn_layer = detr.get_layer(f'input_proj_gn/{i}/1') + if i == 3: + tensor = tf.keras.layers.ZeroPadding2D(1)(tensor) + tensor = input_proj_layer(tensor) + + src_proj_outputs[i] = input_proj_gn_layer(tensor) + else: + input_proj_layer = detr.get_layer(f'input_proj/0/0') + input_proj_gn_layer = detr.get_layer(f'input_proj_gn/0/1') + src_proj_outputs = [input_proj_gn_layer(input_proj_layer(x2))] + + masks = list(tf.zeros([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1], tf.float32) for x in src_proj_outputs) + + decoder, encoder, outputs_coord = layer_transformer(src_proj_outputs, masks) + 
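+    # Note: the core model below exposes the per-layer box predictions together with the
+    # decoder/encoder states, the projected multi-scale feature maps and the raw backbone
+    # outputs, so task-specific heads (classification, bbox, ...) can be attached on top.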
+ detr = tf.keras.Model(image_input, [outputs_coord, decoder, encoder, src_proj_outputs, backbone_outputs], name="detr_core") + + return detr + + +def get_deformable_detr_model( + config, + + include_top=False, + include_bbox=True, + nb_keypoint=None, + nb_class=None, + weights=None, + tf_backbone=False, + + batch_size=None, + num_decoder_layers=6, + num_encoder_layers=6, + + use_mask_bn=False, + + + refine_bbox=False, + return_intermediate_dec=True, + model_dim=256, + multiscale=True, + include_bbox_3d=False, + bbox_3d_config=None, + train_encoder=True, + + ): + if weights == "deformable-detr-refine_bbox" and nb_class is not None and include_top and nb_class != 91: + raise ValueError('"deformable_detr" weights are trained with 92 outputs. Do not include the network top to set this number of class') + elif weights == "deformable_detr" and nb_class is None: + nb_class = 91 + + if weights != "deformable-detr-refine_bbox" and refine_bbox and weights is not None: + raise ValueError('"Trying to instanciate deformable_detr_bbox_refined with deformable_detr weights') + + init_query_embedding = False #if weights == "deformable_detr" else True + # Load model and weights + detr = DeformableDETR(num_classes=nb_class, + num_decoder_layers=num_decoder_layers, + num_encoder_layers=num_encoder_layers, + batch_size=batch_size, + init_query_embedding=init_query_embedding, + use_mask_bn=use_mask_bn, + + refine_bbox=refine_bbox, + return_intermediate_dec=return_intermediate_dec, + + multiscale=multiscale, + train_encoder=train_encoder) + + image_shape = (None, None, 3) + + # Backbone + if not tf_backbone: + backbone = detr.get_layer("backbone") + else: + config.normalized_method = "tf_resnet" + backbone = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_shape=(None, None, 3)) + + if weights is not None: + load_weights(detr, weights) + + # Backbone + if not tf_backbone: + backbone = detr.get_layer("backbone") + else: + backbone = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_shape=image_shape) + + # Get detr core: backbone + transformer + image_input = tf.keras.Input(image_shape, batch_size=batch_size) + + detr_core_outputs = get_detr_core(detr, backbone, model_dim, tf_backbone=tf_backbone, multiscale=multiscale)(image_input) + + if include_top is False and nb_class is None: + return tf.keras.Model(image_input, detr_core_outputs, name="detr_core") + + + outputs_coord, decoder_state, encoder_state, src_proj_output, backbone_outs = detr_core_outputs + + outputs = {"backbone_outs":list(backbone_outs), "src_proj_output":list(src_proj_output), "encoder_state":encoder_state} + + if include_bbox: + + outputs['bbox_pred_boxes'] = outputs_coord[-1] + outputs["bbox_aux"] = [] + for i in range(0, outputs_coord.shape[0] - 1): + outputs["bbox_aux"].append({ + "bbox_pred_boxes": outputs_coord[i] + }) + + # Add bbox head + class_head = DetrClassHead(detr, include_top=include_top, nb_class=nb_class, refine_bbox=refine_bbox) + bbox_outputs = class_head(decoder_state) + config.add_heads([class_head]) + update(outputs, bbox_outputs) + + deformable_detr = tf.keras.Model(image_input, outputs, name="deformable_detr") + return deformable_detr + +def update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = update(d.get(k, {}), v) + else: + d[k] = v + return d + + +if __name__ == "__main__": + main() diff --git a/detr_tf/networks/deformable_transformer.py b/detr_tf/networks/deformable_transformer.py new file mode 100644 index 
00000000..4086c4c7 --- /dev/null +++ b/detr_tf/networks/deformable_transformer.py @@ -0,0 +1,583 @@ +import tensorflow as tf +from tensorflow.keras.layers import Dropout, Activation, LayerNormalization +import math +from .custom_layers import Linear +from .transformer import MultiHeadAttention + + +USE_CUDA_MS_DEFORM_IM2COL = True + +if USE_CUDA_MS_DEFORM_IM2COL: + from detr_tf.custom_ops.ms_deform_attn import ms_deform_im2col +else: + from detr_tf.custom_ops.ms_deform_attn.ms_deform_attn import MSDeformAttnFunction + +class DeformableTransformer(tf.keras.layers.Layer): + def __init__(self, + layer_position_embedding_sine, + level_embed, + class_embed, + bbox_embed, + query_embed_layer=None, + model_dim=256, + num_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + num_sampling_points=4, + dim_feedforward=1024, + dropout=0.1, + activation='relu', + return_intermediate_dec=False, + init_query_embedding=False, + use_track_query=False, + refine_bbox=False, + multiscale=True, + train_encoder=True, + **kwargs): + + super().__init__(**kwargs) + + self.model_dim = model_dim + self.num_heads = num_heads + + self.layer_position_embedding_sine = layer_position_embedding_sine + self.query_embed_layer = query_embed_layer + + self.level_embed = level_embed + + self.class_embed = class_embed + self.bbox_embed = bbox_embed + + self.init_query_embedding = init_query_embedding + + self.multiscale = multiscale + + self.encoder = DeformableEncoder(model_dim, num_heads, dim_feedforward, + dropout, activation, + num_encoder_layers, num_sampling_points=num_sampling_points, name='encoder', trainable=train_encoder) + + self.decoder = DeformableDecoder(class_embed, bbox_embed, model_dim, num_heads, dim_feedforward, + dropout, activation, + num_decoder_layers, + name='decoder', + num_sampling_points=num_sampling_points, refine_bbox=refine_bbox, + return_intermediate=return_intermediate_dec, use_track_query=use_track_query) + + + if self.init_query_embedding: + raise NotImplementedError() + self.query_encoding = self.add_weight(name='query_embedding', shape=(100, 256), + initializer=tf.keras.initializers.GlorotUniform(), trainable=True) + else: + self.init_query_embedding = init_query_embedding + + self.reference_points = Linear(2, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name="reference_points") + + def get_reference_points(self, spatial_shapes): + reference_points_list = [] + for lvl, (H_W_) in enumerate(spatial_shapes): + H_, W_ = tf.unstack(H_W_) + ref_y, ref_x = tf.meshgrid(tf.linspace(0.5, tf.cast(H_, tf.float32) - 0.5, H_), + tf.linspace(0.5, tf.cast(W_, tf.float32) - 0.5, W_), indexing='ij') + + ref_y = ref_y / tf.cast(H_, tf.float32) + ref_x = ref_x / tf.cast(W_, tf.float32) + + ref = tf.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + # L, (H, W, 2) + return reference_points_list + + def call(self, source, mask, track_query=None, track_query_mask=None, track_query_reference_points=None, training=False): + + N = tf.shape(source[0])[0] + + + pos_encoding = list(self.layer_position_embedding_sine(m) for m in mask) + + if self.init_query_embedding: + # Panoptic embedding + query_encoding = self.query_encoding + else: + # Detection embedding + query_encoding = self.query_embed_layer(None) + + query_encoding, target = tf.split(query_encoding, 2, axis=1) + query_encoding = tf.expand_dims(query_encoding, axis=1) + + if self.level_embed is not None: + level_embed = self.level_embed(None) + lvl_pos_embed = list(level_embed[lvl, 
None, :] + tf.reshape(p, (N, -1, self.model_dim) ) for lvl, p in enumerate(pos_encoding) ) # N, (H*W), C + 1, 1, C + lvl_pos_embed_flatten = tf.concat(lvl_pos_embed, axis=1) # N, sum_L(H*W), C + else: + lvl_pos_embed_flatten = None + + # L, 2 + input_spatial_shapes = list(tf.shape(src)[1:3] for src in source) + encoder_reference_points = self.get_reference_points(input_spatial_shapes) + input_spatial_shapes = tf.stack(input_spatial_shapes, 0) + + # L, + input_level_start_index = tf.math.reduce_prod(input_spatial_shapes, axis=1) + input_level_start_index = tf.math.cumsum(input_level_start_index, axis=0, exclusive=True) + + # L, (H, W, 2) -> L, (1, H*W, 2) + encoder_reference_points = list( tf.reshape(rp_l, (1, -1, 2)) for rp_l in encoder_reference_points) + encoder_reference_points = tf.concat(encoder_reference_points, axis=1) + + + + #Flatten sources + source = [tf.reshape(s, (N, -1, self.model_dim) ) for s in source] + source = tf.concat(source, axis=1) + memory = self.encoder(source, encoder_reference_points, source_key_padding_mask=mask, + pos_encoding=lvl_pos_embed_flatten, + training=training, + source_spatial_shapes=input_spatial_shapes, + source_level_start_index=input_level_start_index) + + + decoder_reference_points = tf.math.sigmoid(self.reference_points(query_encoding)) + decoder_reference_points = tf.tile(decoder_reference_points, [1, N, 1]) + + if track_query_reference_points is not None: + decoder_reference_points = tf.concat([track_query_reference_points, decoder_reference_points], axis=0) + + target = tf.reshape(target, (300, 1, self.model_dim) ) + target = tf.tile(target, [1, N, 1]) + + + hs, reference_points = self.decoder(target, memory, decoder_reference_points, memory_key_padding_mask=mask, + pos_encoding=lvl_pos_embed_flatten, query_encoding=query_encoding, + track_query=track_query, track_query_mask=track_query_mask, + memory_spatial_shapes=input_spatial_shapes, + memory_level_start_index=input_level_start_index, + training=training) + + return tf.transpose(hs, [0, 2, 1, 3]), tf.transpose(memory, (1, 0, 2)), tf.transpose(reference_points, [0, 2, 1, 3]) + + +class DeformableEncoder(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, dim_feedforward=2048, + dropout=0.1, activation='relu', + num_encoder_layers=6, num_sampling_points=4, **kwargs): + super().__init__(**kwargs) + + self.enc_layers = [DeformableEncoderLayer(model_dim, num_heads, num_sampling_points, dim_feedforward, + dropout, activation, + name='layer_%d'%i) + for i in range(num_encoder_layers)] + + + def call(self, source, reference_points, mask=None, source_key_padding_mask=None, + pos_encoding=None, track_query=None, + source_spatial_shapes=None, source_level_start_index=None, training=False): + x = source + + for l_id, layer in enumerate(self.enc_layers): + x = layer(x, reference_points, source_mask=mask, source_key_padding_mask=source_key_padding_mask, + pos_encoding=pos_encoding, input_spatial_shapes=source_spatial_shapes, input_level_start_index=source_level_start_index, training=training) + + return x + + +class DeformableDecoder(tf.keras.layers.Layer): + def __init__(self, class_embed, bbox_embed, model_dim=256, num_heads=8, dim_feedforward=2048, + dropout=0.1, activation='relu', + num_decoder_layers=6, num_sampling_points=4, return_intermediate=False, use_track_query=False, refine_bbox=False, **kwargs): + super().__init__(**kwargs) + + self.dec_layers = [DeformableDecoderLayer(model_dim, num_heads, num_sampling_points, dim_feedforward, + dropout, activation, 
use_track_query=use_track_query, + name='layer_%d'%i) + for i in range(num_decoder_layers)] + + self.class_embed = class_embed + self.bbox_embed = bbox_embed + + self.refine_bbox = refine_bbox + self.return_intermediate = return_intermediate + + + def call(self, target, memory, reference_points, target_mask=None, memory_mask=None, + target_key_padding_mask=None, memory_key_padding_mask=None, memory_spatial_shapes=None, memory_level_start_index=None, + pos_encoding=None, query_encoding=None, track_query=None, track_query_mask=None, training=False): + + + x = target + intermediate = [] + intermediate_reference_points = [] + + new_reference_points = reference_points + + for l_id, layer in enumerate(self.dec_layers): + # if the tracking is not use + # track_query we'll simply stay None. + x, track_query = layer(x, memory, reference_points, + target_mask=target_mask, + memory_mask=memory_mask, + target_key_padding_mask=target_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos_encoding=pos_encoding, + track_query=track_query, + track_query_mask=track_query_mask, + query_encoding=query_encoding, + memory_spatial_shapes=memory_spatial_shapes, + memory_level_start_index=memory_level_start_index) + + if track_query is not None: + out = tf.concat([track_query, x], axis=0) + else: + out = x + + tmp = self.bbox_embed[l_id](out) + if self.refine_bbox: + new_reference_points = inverse_sigmoid(new_reference_points) + else: + new_reference_points = inverse_sigmoid(reference_points) + + if new_reference_points.shape[-1] == 4: + new_reference_points = tmp + new_reference_points + elif new_reference_points.shape[-1] == 2: + xy = tmp[..., :2] + new_reference_points + hw = tmp[..., 2:] + new_reference_points = tf.concat([xy, hw], axis=-1) + else: + raise ValueError() + + + new_reference_points = tf.math.sigmoid(new_reference_points) + + if self.refine_bbox: + reference_points = tf.stop_gradient(new_reference_points) + + if self.return_intermediate: + intermediate.append(out) + intermediate_reference_points.append(new_reference_points) + + + if self.return_intermediate: + return tf.stack(intermediate, axis=0), tf.stack(intermediate_reference_points, axis=0) + + return out, reference_points + + +class DeformableEncoderLayer(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, num_sampling_points=4, dim_feedforward=2048, + dropout=0.1, activation='relu', + **kwargs): + super().__init__(**kwargs) + + self.self_attn = MSDeformableAttention(model_dim, num_heads, num_sampling_points, dropout=dropout, + name='self_attn') + + self.dropout = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + + self.activation = Activation(activation) + + self.model_dim = model_dim + + self.linear1 = Linear(dim_feedforward, name='linear1') + self.linear2 = Linear(model_dim, name='linear2') + + self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1') + self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2') + + + def call(self, source, reference_points, source_mask=None, source_key_padding_mask=None, + pos_encoding=None, input_spatial_shapes=None, input_level_start_index=None, training=False): + """ + :param source (N, H, W, C) + :param pos_encoding (1, sum_L(H*W), C) + :param reference_points (H, W, 2) + + :return output (N, sum_L(H*W), C) + """ + + N = tf.shape(source[0])[0] + C = self.model_dim + + if pos_encoding is None: + query = source + else: + query = source + pos_encoding + + # # Multi-scale level embedding L, (N, H, W, C) + # query = list(q 
+ level_embed[lvl, None, None, :] for lvl, q in enumerate(query)) + + # # Flatten ¤¤¤¤ + # # L, (N, H*W, C) + # query = list(tf.reshape(q, (N, -1, C) ) for q in query) + # # (N, sum_L{H_*W_}, C) + # query = tf.concat(query, axis=1) + + # (N, Length_{query}, C) + + #print("query", query.shape) + attn_source = self.self_attn(query, reference_points, source, input_spatial_shapes, input_level_start_index) + #print("attn_source", attn_source.shape) + # src = list(tf.reshape(s, (N, -1, C) ) for s in source) + # # (N, sum_L{H_*W_}, C) + # src = tf.concat(src, axis=1) + + source += self.dropout(attn_source, training=training) + source = self.norm1(source) + + #forward_ffn + x = self.linear1(source) + x = self.activation(x) + x = self.dropout2(x, training=training) + x = self.linear2(x) + source += self.dropout3(x, training=training) + source = self.norm2(source) + + # #Unflatten ¤¤¤¤ + # split_size = list(iss[0]*iss[1] for iss in input_spatial_shapes) + # # L, (N, H*W, 2) + # src = tf.split(src, split_size, axis=1) + # # L, (N, H, W, 2) + # src = list(tf.reshape(el, (N, iss[0], iss[1], C) ) for iss, el in zip(input_spatial_shapes, src)) + + return source + + + +class DeformableDecoderLayer(tf.keras.layers.Layer): + def __init__(self, model_dim=256, num_heads=8, num_sampling_points=4, dim_feedforward=2048, + dropout=0.1, activation='relu', use_track_query=False, + **kwargs): + super().__init__(**kwargs) + + self.self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout, + name='self_attn') + self.cross_attn = MSDeformableAttention(model_dim, num_heads, dropout=dropout, + name='cross_attn') + + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + self.dropout3 = Dropout(dropout) + self.dropout4 = Dropout(dropout) + + self.activation = Activation(activation) + + self.linear1 = Linear(dim_feedforward, name='linear1') + self.linear2 = Linear(model_dim, name='linear2') + + self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1') + self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2') + self.norm3 = LayerNormalization(epsilon=1e-5, name='norm3') + + self.use_track_query = use_track_query + + # if self.use_track_query: + # self.dropout = Dropout(dropout) + # self.track_query_norm = LayerNormalization(epsilon=1e-5, name='track_query_norm') + # self.track_query_self_attn = MultiHeadAttention(model_dim, num_heads, dropout=dropout, name='track_query_self_attn') + + + def call(self, target, memory, reference_points, target_mask=None, memory_mask=None, + target_key_padding_mask=None, memory_key_padding_mask=None, memory_spatial_shapes=None, memory_level_start_index=None, + pos_encoding=None, track_query=None, track_query_mask=None, query_encoding=None, level_embed=None, training=False): + + + if track_query is not None: + # track_query_query = track_query + # track_query_key = track_query + # track_query_target = track_query + + if target_key_padding_mask is None: + target_shape = tf.shape(target) + target_key_padding_mask = tf.zeros((target_shape[1], target_shape[0])) + + # track_query_attn_target = self.track_query_self_attn((track_query_query, track_query_key, track_query_target), key_padding_mask=track_query_mask, need_weights=False) + + # track_query_target += self.dropout(track_query_attn_target, training=training) + # track_query_target = self.track_query_norm(track_query_target) + nb_trackquery = tf.shape(track_query)[0] + + # Pad on the left the original query + query_encoding = tf.pad(query_encoding, [[nb_trackquery, 0], [0, 0], [0, 0]], "CONSTANT" ) + # Concat 
with the track query on the left + target = tf.concat([track_query, target], axis=0) + target_key_padding_mask = tf.concat([track_query_mask, target_key_padding_mask], axis=1) + + # If we use the track query, the query encoding is now padded with zeros for the track queries + # query_tgt = target + query_encoding + query_tgt = key_tgt = target + query_encoding + + attn_target = self.self_attn((query_tgt, key_tgt, target), attn_mask=target_mask, + key_padding_mask=target_key_padding_mask, + need_weights=False) + + target += self.dropout2(attn_target, training=training) + target = self.norm2(target) + + query_tgt = target + query_encoding + + query_tgt = tf.transpose(query_tgt, (1, 0, 2) ) + reference_points = tf.transpose(reference_points, (1, 0, 2) ) + + attn_target2 = self.cross_attn(query_tgt, reference_points, memory, + input_spatial_shapes=memory_spatial_shapes, input_level_start_index=memory_level_start_index) + attn_target2 = tf.transpose(attn_target2, (1, 0, 2) ) + + target += self.dropout1(attn_target2, training=training) + target = self.norm1(target) + + x = self.linear1(target) + x = self.activation(x) + x = self.dropout3(x, training=training) + x = self.linear2(x) + target += self.dropout4(x, training=training) + target = self.norm3(target) + + if track_query is not None: + n_track_query = target[:nb_trackquery] + target = target[nb_trackquery:] + return target, n_track_query + else: + return target, None + + +class SamplingOffsetBiasInitializer(tf.keras.initializers.Initializer): + + def __init__(self, num_heads, num_level, n_points): + self.num_heads = num_heads + self.num_level = num_level + self.n_points = n_points + + def __call__(self, shape, dtype=None, **kwargs): + thetas = tf.range(self.num_heads, dtype=tf.float32) * (2.0 * math.pi / self.num_heads) + grid_init = tf.stack([tf.math.cos(thetas), tf.math.sin(thetas)], axis=-1) + grid_init = grid_init / tf.math.reduce_max(tf.abs(grid_init), axis=-1, keepdims=True)[0] + grid_init = tf.reshape(grid_init, (self.num_heads, 1, 1, 2) ) + # self.num_heads, self.num_level, self.n_points, 2 + grid_init = tf.tile(grid_init, (1, self.num_level, self.n_points, 1) ) + + scaling = tf.range(self.n_points, dtype = tf.float32) + 1.0 + scaling = tf.reshape(scaling, (1, 1, self.n_points , 1) ) + grid_init = grid_init * scaling + + grid_init = tf.reshape(grid_init, (-1,)) + + return grid_init + + + +class MSDeformableAttention(tf.keras.layers.Layer): + def __init__(self, model_dim, num_heads, num_sampling_points = 4, num_level=4, dropout=0.0, **kwargs): + super().__init__(**kwargs) + + self.model_dim = model_dim + self.num_heads = num_heads + + self.num_level = num_level + self.num_sampling_points = num_sampling_points + + assert model_dim % num_heads == 0 + self.head_dim = model_dim // num_heads + + self.dropout = Dropout(rate=dropout) + + self.im2col_step = 64 + + + def build(self, input_shapes): + + self.sampling_offsets = Linear(self.num_heads * self.num_level * self.num_sampling_points * 2, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=SamplingOffsetBiasInitializer(self.num_heads, self.num_level, self.num_sampling_points), + name="sampling_offsets") + + self.attention_weights = Linear(self.num_heads * self.num_level * self.num_sampling_points, + kernel_initializer=tf.keras.initializers.Zeros(), + bias_initializer=tf.keras.initializers.Zeros(), + name="attention_weights") + + self.value_proj = Linear(self.model_dim, bias_initializer=tf.keras.initializers.Zeros(), name="value_proj") + + self.output_proj = 
Linear(self.model_dim, name="output_proj") + + + def call(self, query, reference_points, inputs, input_spatial_shapes=None, input_level_start_index=None, input_padding_mask=None, training=False): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, 4), add additional (w, h) to form reference boxes + + :param inputs (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + or + :param inputs lvl, (N, H_l, W_l, C) + + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + :return output (N, Length_{query}, C) + """ + + #debug purpose + unstack_size = list(iss[0]*iss[1] for iss in tf.unstack(input_spatial_shapes, axis=0)) + unstack_shape = list( (iss[0], iss[1]) for iss in tf.unstack(input_spatial_shapes, axis=0)) + + N, Len_q, C = tf.unstack(tf.shape(query)) + + N, Len_in, _ = tf.unstack(tf.shape(inputs)) + value = self.value_proj(inputs) + value = tf.reshape(value, (N, Len_in, self.num_heads, self.head_dim)) + + + sampling_offsets = self.sampling_offsets(query) + sampling_offsets = tf.reshape(sampling_offsets, (N, Len_q, self.num_heads, self.num_level, self.num_sampling_points, 2) ) + + + attention_weights = self.attention_weights(query) + attention_weights = tf.reshape(attention_weights, (N, Len_q, self.num_heads, self.num_level * self.num_sampling_points) ) + attention_weights = tf.nn.softmax(attention_weights, axis=-1) + attention_weights = tf.reshape(attention_weights, (N, Len_q, self.num_heads, self.num_level, self.num_sampling_points) ) + + + # (N, Len_q, num_heads, num_level, num_sampling_points, _) + if reference_points.shape[-1] == 2: + offset_normalizer = tf.cast(tf.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1), tf.float32) + sampling_locations = reference_points[:, :, None, None, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, None, None, :2] + sampling_offsets / self.num_sampling_points * reference_points[:, :, None, None, None, 2:] * 0.5 + else: + raise ValueError(f"reference_points shape must be defined, got {reference_points.shape[-1]}") + + if USE_CUDA_MS_DEFORM_IM2COL: + # Flatten and call custom op ! 
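+            # The custom CUDA op bilinearly samples `value` at each predicted sampling
+            # location (per head, level and point) and combines the samples with
+            # `attention_weights`, i.e. the multi-scale deformable attention aggregation.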
+ output = ms_deform_im2col( + value, # (N, Len_in, n_heads, d_model#n_heads) + input_spatial_shapes, # (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + input_level_start_index, # (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + sampling_locations, # (N, Len_q, n_heads, n_levels, n_points, 2) + attention_weights # (N, Len_q, num_heads, n_level, num_sampling_points) + ) + else: + #Unflatten + value = tf.split(value, num_or_size_splits=unstack_size, axis=1) + value = list(tf.reshape(v, (N, shape[0], shape[1], self.num_heads, self.head_dim) ) for v, shape in zip(value, unstack_shape) ) + + sampling_loc = tf.unstack(sampling_locations, axis=3) #(N, Len_q, n_heads, num_sampling_points) + + output = MSDeformAttnFunction(value, sampling_loc, attention_weights) + + + output = self.output_proj(output) + + + return output + +def inverse_sigmoid(x, eps=1e-5): + x = tf.clip_by_value(x, 0.0, 1.0) + x1 = tf.clip_by_value(x, eps, 1.0) + + x2 = (1 - x) + x2 = tf.clip_by_value(x2, eps, 1.0) + return tf.math.log(x1/x2) diff --git a/detr_tf/networks/detr.py b/detr_tf/networks/detr.py index 4a52529b..04f09156 100644 --- a/detr_tf/networks/detr.py +++ b/detr_tf/networks/detr.py @@ -168,7 +168,7 @@ def get_detr_model(config, include_top=False, nb_class=None, weights=None, tf_ba bbox_embed_linear3 = detr.get_layer('bbox_embed_2') activation = detr.get_layer("re_lu") - x = backbone(image_input) + x, _, _, _ = backbone(image_input) # Resize the mask to the same size of the backbone outptu masks = tf.image.resize(image_mask, (tf.shape(x)[1], tf.shape(x)[2]), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) diff --git a/detr_tf/networks/resnet_backbone.py b/detr_tf/networks/resnet_backbone.py index d81a5872..5e0fbf1f 100644 --- a/detr_tf/networks/resnet_backbone.py +++ b/detr_tf/networks/resnet_backbone.py @@ -25,11 +25,11 @@ def call(self, x): x = self.pad2(x) x = self.maxpool(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - return x + l1 = self.layer1(x) + l2 = self.layer2(l1) + l3 = self.layer3(l2) + l4 = self.layer4(l3) + return l4, l3, l2, l1 class ResNet50Backbone(ResNetBase): diff --git a/detr_tf/networks/weights.py b/detr_tf/networks/weights.py index 87015558..c59f5cef 100644 --- a/detr_tf/networks/weights.py +++ b/detr_tf/networks/weights.py @@ -7,6 +7,16 @@ "https://storage.googleapis.com/visualbehavior-publicweights/detr/checkpoint", "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.data-00000-of-00001", "https://storage.googleapis.com/visualbehavior-publicweights/detr/detr.ckpt.index" + ], + "deformable-detr": [ + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/checkpoint", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/deformable-detr.ckpt.data-00000-of-00001", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr/deformable-detr.ckpt.index" + ], + "deformable-detr-refine_bbox": [ + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/checkpoint", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/deformable-detr-refine_bbox.ckpt.data-00000-of-00001", + "https://storage.googleapis.com/visualbehavior-publicweights/deformable-detr-refine_bbox/deformable-detr-refine_bbox.ckpt.index" ] } diff --git a/webcam_inference.py b/webcam_inference.py index 63871187..247201bf 100644 --- a/webcam_inference.py +++ 
b/webcam_inference.py @@ -3,19 +3,29 @@ import cv2 from detr_tf.training_config import TrainingConfig, training_config_parser + from detr_tf.networks.detr import get_detr_model +from detr_tf.networks.deformable_detr import get_deformable_detr_model + from detr_tf.data import processing from detr_tf.data.coco import COCO_CLASS_NAME from detr_tf.inference import get_model_inference, numpy_bbox_to_image + @tf.function -def run_inference(model, images, config): - m_outputs = model(images, training=False) - predicted_bbox, predicted_labels, predicted_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center") +def run_inference(model, images, config, use_mask=True): + + if use_mask: + mask = tf.zeros((1, images.shape[1], images.shape[2], 1)) + m_outputs = model((images, mask), training=False) + else: + m_outputs = model(images, training=False) + + predicted_bbox, predicted_labels, predicted_scores = get_model_inference(m_outputs, config.background_class, bbox_format="xy_center", threshold=0.2) return predicted_bbox, predicted_labels, predicted_scores -def run_webcam_inference(detr): +def run_webcam_inference(model, use_mask=True): cap = cv2.VideoCapture(0) @@ -27,7 +37,7 @@ def run_webcam_inference(detr): model_input = processing.normalized_images(model_input, config) # Run inference - predicted_bbox, predicted_labels, predicted_scores = run_inference(detr, np.expand_dims(model_input, axis=0), config) + predicted_bbox, predicted_labels, predicted_scores = run_inference(model, np.expand_dims(model_input, axis=0), config, use_mask=use_mask) frame = frame.astype(np.float32) frame = frame / 255 @@ -52,8 +62,11 @@ def run_webcam_inference(detr): config.update_from_args(args) # Load the model with the new layers to finetune - detr = get_detr_model(config, include_top=True, weights="detr") - config.background_class = 91 + #detr = get_detr_model(config, include_top=True, weights="detr") + #config.background_class = 91 + + deformable_detr = get_deformable_detr_model(config, include_top=True, weights="deformable-detr") + deformable_detr.summary() # Run webcam inference - run_webcam_inference(detr) + run_webcam_inference(deformable_detr, use_mask=False) From cfbdd8dcdbac2916bf62915990de751d63773fa6 Mon Sep 17 00:00:00 2001 From: Thibault Date: Tue, 15 Jun 2021 11:39:23 +0200 Subject: [PATCH 4/5] Deformable inference + Detr --- detr_tf/data/coco.py | 6 +++--- detr_tf/training_config.py | 2 +- eval.py | 2 +- webcam_inference.py | 25 +++++++++++++++++-------- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/detr_tf/data/coco.py b/detr_tf/data/coco.py index 938b776e..d5f19bd9 100644 --- a/detr_tf/data/coco.py +++ b/detr_tf/data/coco.py @@ -112,7 +112,7 @@ def iter_tuple_to_dict(data): } -def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None, shuffle=True): +def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_file=None, img_dir=None, shuffle_data=True): """ Load a coco dataset Parameters @@ -158,11 +158,11 @@ def load_coco_dataset(config, batch_size, augmentation=False, ann_dir=None, ann_ # Setup the data pipeline img_ids = coco.getImgIds() - if shuffle: + if shuffle_data: shuffle(img_ids) dataset = tf.data.Dataset.from_tensor_slices(img_ids) # Shuffle the dataset - if shuffle: + if shuffle_data: dataset = dataset.shuffle(1000) # Retrieve img and labels diff --git a/detr_tf/training_config.py b/detr_tf/training_config.py index c3b254d2..63a0adbb 100644 --- a/detr_tf/training_config.py +++ 
b/detr_tf/training_config.py @@ -106,7 +106,7 @@ def update_from_args(self, args): """ args = vars(args) for key in args: - if isinstance(getattr(self, key), tf.Variable): + if isinstance(getattr(self, key, None), tf.Variable): getattr(self, key).assign(args[key]) else: setattr(self, key, args[key]) diff --git a/eval.py b/eval.py index 899372cf..93b9bc7c 100644 --- a/eval.py +++ b/eval.py @@ -92,7 +92,7 @@ def eval_model(model, config, class_names, valid_dt): # Load the model with the new layers to finetune detr = build_model(config) - valid_dt, class_names = load_coco_dataset(config, 1, augmentation=False, shuffle=False) + valid_dt, class_names = load_coco_dataset(config, 1, augmentation=False, shuffle_data=False) # Run training eval_model(detr, config, class_names, valid_dt) diff --git a/webcam_inference.py b/webcam_inference.py index 247201bf..2e26109d 100644 --- a/webcam_inference.py +++ b/webcam_inference.py @@ -58,15 +58,24 @@ def run_webcam_inference(model, use_mask=True): tf.config.experimental.set_memory_growth(physical_devices[0], True) config = TrainingConfig() - args = training_config_parser().parse_args() + parser = training_config_parser() + + # Logging + parser.add_argument("model", type=str, help="One of 'detr', or 'deformable-detr'") + args = parser.parse_args() config.update_from_args(args) - # Load the model with the new layers to finetune - #detr = get_detr_model(config, include_top=True, weights="detr") - #config.background_class = 91 - - deformable_detr = get_deformable_detr_model(config, include_top=True, weights="deformable-detr") - deformable_detr.summary() + if args.model == "detr": + print("Loading detr...") + # Load the model with the new layers to finetune + model = get_detr_model(config, include_top=True, weights="detr") + config.background_class = 91 + use_mask = True + elif args.model == "deformable-detr": + print("Loading deformable-detr...") + model = get_deformable_detr_model(config, include_top=True, weights="deformable-detr") + model.summary() + use_mask = False # Run webcam inference - run_webcam_inference(deformable_detr, use_mask=False) + run_webcam_inference(model, use_mask=use_mask) From 7c0fa70ebdb741f397252a30565d5dd368caeef8 Mon Sep 17 00:00:00 2001 From: Thibault Date: Fri, 18 Jun 2021 08:55:31 +0200 Subject: [PATCH 5/5] Replace batch norm by group norm in deformable --- detr_tf/networks/deformable_detr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/detr_tf/networks/deformable_detr.py b/detr_tf/networks/deformable_detr.py index f9ae4108..9bc51c00 100644 --- a/detr_tf/networks/deformable_detr.py +++ b/detr_tf/networks/deformable_detr.py @@ -93,7 +93,7 @@ def __init__(self, ) self.model_dim = model_dim - layer_norm = tf.keras.layers.BatchNormalization + layer_norm = functools.partial(tfa.layers.GroupNormalization, groups=32, epsilon=1e-05) #tf.keras.layers.BatchNormalization self.input_proj_0 = tf.keras.layers.Conv2D(self.model_dim, kernel_size=1, name='input_proj/0/0', trainable=train_encoder) self.input_proj_gn_0 = layer_norm(name="input_proj_gn/0/1", trainable=train_encoder) @@ -237,7 +237,6 @@ def get_deformable_detr_model( include_top=False, include_bbox=True, - nb_keypoint=None, nb_class=None, weights=None, tf_backbone=False,