'''
A module that binds the YOLOv7 repo with DeepSORT, with some modifications.
'''

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # suppress TensorFlow logging; remove this line to re-enable log output
import time
import tensorflow as tf

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

import cv2
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.compat.v1 import ConfigProto  # the official DeepSORT implementation uses TF 1.x, so a few compatibility tweaks are needed to avoid errors

# deep sort imports
from deep_sort import preprocessing, nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker

# imports from helpers
from tracking_helpers import read_class_names, create_box_encoder
from detection_helpers import *


# load configuration for the object detector
config = ConfigProto()
config.gpu_options.allow_growth = True


class YOLOv7_DeepSORT:
    '''
    Wraps any YOLO-style detector with DeepSORT tracking.
    '''
    def __init__(self, reID_model_path:str, detector, max_cosine_distance:float=0.4, nn_budget:float=None, nms_max_overlap:float=1.0,
                 coco_names_path:str="./io_data/input/classes/coco.names"):
        '''
        args:
            reID_model_path: path of the model that generates the embeddings of the cropped area, used for re-identification
            detector: a YOLO model object, or any model that returns detections as [x1, y1, x2, y2, score, class]
            max_cosine_distance: cosine distance threshold for deciding that two appearances belong to the same object
            nn_budget: if not None, fix samples per class to at most this number; the oldest samples are removed when the budget is reached
            nms_max_overlap: maximum overlap allowed by the tracker's non-max suppression
            coco_names_path: file which contains the COCO class names
        '''
        self.detector = detector
        self.coco_names_path = coco_names_path
        self.nms_max_overlap = nms_max_overlap
        self.class_names = read_class_names()

        # initialize deep sort
        self.encoder = create_box_encoder(reID_model_path, batch_size=1)
        metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)  # cosine distance metric for appearance matching
        self.tracker = Tracker(metric)  # initialize tracker


    def track_video(self, video:str, output:str, skip_frames:int=0, show_live:bool=False, count_objects:bool=False, verbose:int=0):
        '''
        Track objects in a given video or webcam stream.
        args:
            video: path to the input video, or 0 for the webcam
            output: path to the output video
            skip_frames: skip every nth frame; the saved video may look choppy because of the skipped frames
            show_live: whether to show the tracking live. Press 'q' to quit
            count_objects: display the number of objects being tracked on screen
            verbose: print details to the console; allowed values are 0, 1 and 2
        '''
        try:  # begin video capture
            vid = cv2.VideoCapture(int(video))
        except:
            vid = cv2.VideoCapture(video)

        out = None
        if output:  # get the video writer ready to save locally if the flag is set
            width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))  # by default VideoCapture returns float instead of int
            height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(vid.get(cv2.CAP_PROP_FPS))
            codec = cv2.VideoWriter_fourcc(*"XVID")
            out = cv2.VideoWriter(output, codec, fps, (width, height))

        frame_num = 0
        while True:  # while the video is running
            return_value, frame = vid.read()
            if not return_value:
                print('Video has ended or failed!')
                break
            frame_num += 1

            if skip_frames and not frame_num % skip_frames: continue  # skip every nth frame; use this to speed things up when not every frame matters
            if verbose >= 1: start_time = time.time()

            # ----------------------------------------- PUT ANY DETECTION MODEL HERE -----------------------------------------------------------------
            yolo_dets = self.detector.detect(frame.copy(), plot_bb=False)  # get the detections as [x1, y1, x2, y2, score, class]
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if yolo_dets is None:  # guard against frames with no detections
                yolo_dets = np.empty((0, 6))

            bboxes = yolo_dets[:, :4]
            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]  # convert from xyxy to xywh
            bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]

            scores = yolo_dets[:, 4]
            classes = yolo_dets[:, -1]
            num_objects = bboxes.shape[0]
            # ---------------------------------------- DETECTION PART COMPLETED ---------------------------------------------------------------------

            names = []
            for i in range(num_objects):  # loop through the detections and map each class index to its class name
                class_indx = int(classes[i])
                class_name = self.class_names[class_indx]
                names.append(class_name)

            names = np.array(names)
            count = len(names)

            if count_objects:
                cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1.5, (0, 0, 0), 2)

            # ---------------------------------- DeepSORT tracker work starts here ------------------------------------------------------------
            features = self.encoder(frame, bboxes)  # encode detections and feed to tracker. Shape: [no. of detections per frame, embed_size]
            detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]  # one deep_sort.detection.Detection object per detection in the frame

            cmap = plt.get_cmap('tab20b')  # initialize color map
            colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

            boxs = np.array([d.tlwh for d in detections])  # run non-maxima suppression below
            scores = np.array([d.confidence for d in detections])
            classes = np.array([d.class_name for d in detections])
            indices = preprocessing.non_max_suppression(boxs, classes, self.nms_max_overlap, scores)
            detections = [detections[i] for i in indices]

            self.tracker.predict()  # advance the Kalman filter state for each track
            self.tracker.update(detections)  # update the tracks with the matched detections

            for track in self.tracker.tracks:  # draw the confirmed, recently updated tracks
                if not track.is_confirmed() or track.time_since_update > 1:
                    continue
                bbox = track.to_tlbr()
                class_name = track.get_class()

                color = colors[int(track.track_id) % len(colors)]  # draw bbox on screen
                color = [i * 255 for i in color]
                cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
                cv2.rectangle(frame, (int(bbox[0]), int(bbox[1]-30)), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1])), color, -1)
                cv2.putText(frame, class_name + " : " + str(track.track_id), (int(bbox[0]), int(bbox[1]-11)), 0, 0.6, (255, 255, 255), 1, lineType=cv2.LINE_AA)

                if verbose == 2:
                    print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))

            # -------------------------------- Tracker work ENDS here -----------------------------------------------------------------------
            if verbose >= 1:
                fps = 1.0 / (time.time() - start_time)  # frames per second of the running detections
                if not count_objects: print(f"Processed frame no: {frame_num} || Current FPS: {round(fps, 2)}")
                else: print(f"Processed frame no: {frame_num} || Current FPS: {round(fps, 2)} || Objects tracked: {count}")

            result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            if output: out.write(result)  # save the output video

            if show_live:
                cv2.imshow("Output Video", result)
                if cv2.waitKey(1) & 0xFF == ord('q'): break

        vid.release()
        if out is not None: out.release()
        cv2.destroyAllWindows()
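

# A minimal usage sketch, not part of the original module: it assumes detection_helpers exposes a
# `Detector` class with `load_model()` and `detect()`, and that the weight/model/video paths below
# exist on disk. Adjust the names and paths to match your own setup.
if __name__ == '__main__':
    detector = Detector()                                            # hypothetical YOLOv7 wrapper from detection_helpers
    detector.load_model('./weights/yolov7.pt')                       # assumed path to the YOLOv7 weights

    tracker = YOLOv7_DeepSORT(reID_model_path='./deep_sort/model_weights/mars-small128.pb',  # assumed path to the re-ID embedding model
                              detector=detector)
    tracker.track_video('./io_data/input/video/street.mp4',          # assumed input video path
                        output='./io_data/output/street_tracked.avi',
                        show_live=False, count_objects=True, verbose=1)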