
Commit 3643335

Merge pull request #87 from cleanlab/kfold_early
add early stopping support for object detection example
2 parents f85155b + 704a473 commit 3643335

File tree: 1 file changed (+103 −26)

object_detection/detectron2_training-kfold.ipynb

Lines changed: 103 additions & 26 deletions
@@ -25,13 +25,16 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "from detectron2.engine import DefaultTrainer\n",
 "from detectron2.config import get_cfg\n",
 "import pickle\n",
 "# import some common libraries\n",
+"from detectron2.data import build_detection_test_loader, build_detection_train_loader\n",
 "import numpy as np\n",
 "import os, json, cv2, random\n",
 "from detectron2.data import build_detection_test_loader\n",
@@ -52,12 +55,15 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "!wget -nc \"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\" && unzip -q -o annotations_trainval2017.zip\n",
 "!wget -nc \"http://images.cocodataset.org/zips/val2017.zip\" && unzip -q -o val2017.zip\n",
-"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip"
+"!wget -nc \"http://images.cocodataset.org/zips/train2017.zip\" && unzip -q -o train2017.zip\n",
+"!wget -nc \"https://cleanlab-public.s3.amazonaws.com/ObjectDetectionBenchmarking/tutorial/TRAIN_COCO_ALL_labels.pkl\""
 ]
 },
 {
@@ -92,7 +98,9 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "import json\n",
@@ -141,13 +149,26 @@
 " annotations_count = len(data_dict['annotations'])\n",
 " print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
 "\n",
+" \n",
+"def unregister_coco_instances(name):\n",
+" if name in DatasetCatalog.list():\n",
+" DatasetCatalog.remove(name)\n",
+" MetadataCatalog.remove(name)\n",
+"\n",
 "# Generate K-Fold cross-validation\n",
 "kf = KFold(n_splits=NUM_FOLDS)\n",
 "pairs = []\n",
 "for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
 " train_data, test_data = split_data(train_indices, test_indices)\n",
 " train_file = f\"train_coco_{fold}_fold.json\"\n",
 " test_file = f\"test_coco_{fold}_fold.json\"\n",
+" # Unregister instances with the same names only if they exist\n",
+" unregister_coco_instances(train_file)\n",
+" unregister_coco_instances(test_file)\n",
+" # Register COCO instances for training and validation. \n",
+" # Note: The 'train2017' folder is retained as the base path for images.\n",
+" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
+" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
 " pairs.append([train_file,test_file])\n",
 " with open(train_file, 'w') as train_file:\n",
 " json.dump(train_data, train_file)\n",
@@ -156,7 +177,9 @@
 " print(f\"Data info for training data fold {fold}:\")\n",
 " print_data_info(train_data, fold)\n",
 " print(f\"Data info for test data fold {fold}:\")\n",
-" print_data_info(test_data, fold)\n"
+" print_data_info(test_data, fold)\n",
+" \n",
+"TRAIN_PATH = os.path.join(os.getcwd(),\"train2017\")"
 ]
 },
 {
@@ -175,36 +198,83 @@
 "The number of worker threads is set to 2 and the batch size is set to 2.\n",
 "The learning rate and maximum number of iterations are also specified. The model is initialized from the COCO-Detection model zoo and the output directory for the trained model is created. Finally, the configuration is passed to the DefaultTrainer class for training the object detection model.\n",
 "\n",
-"<strong>Note:</strong> The number of iterations was set based on [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.)"
+"<strong>Note:</strong> The choice of the number of iterations is informed by the incorporation of [early stopping.](https://en.wikipedia.org/wiki/Early_stopping#:~:text=In%20machine%20learning%2C%20early%20stopping,training%20data%20with%20each%20iteration.) This technique monitors the validation loss throughout training, saving the model upon improvement and halting training if no progress is observed within a defined patience period. Early stopping aims to identify an optimal model iteration, mitigating the risk of overfitting."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"scrolled": true
+"scrolled": true,
+"tags": []
 },
 "outputs": [],
 "source": [
-"def train_data(TRAIN,VALIDATION,folder):\n",
+"class Early_stopping(DefaultTrainer):\n",
+" def __init__(self, cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\"):\n",
+" super().__init__(cfg)\n",
+" self.early_stop_patience = early_stop_patience\n",
+" self.model_checkpoint_path = model_checkpoint_path\n",
+" self.best_validation_loss = float('inf')\n",
+" self.current_patience = 0\n",
+"\n",
+" def build_train_loader(self, cfg):\n",
+" return build_detection_train_loader(cfg)\n",
+" \n",
+" def data_loader_mapper(self, batch):\n",
+" return batch\n",
+"\n",
+" def run_hooks(self):\n",
+" val_loss = self.validation()\n",
+" if val_loss < self.best_validation_loss:\n",
+" self.best_validation_loss = val_loss\n",
+" self.current_patience = 0\n",
+" self.save_checkpoint()\n",
+" else:\n",
+" self.current_patience += 1\n",
+" if self.current_patience >= self.early_stop_patience:\n",
+" self._trainer.save_checkpoint()\n",
+" self._trainer.has_finished = True\n",
+"\n",
+" def validation(self):\n",
+" # Define evaluator here\n",
+" evaluator = COCOEvaluator(self.cfg.DATASETS.TEST[0], self.cfg, True, output_dir=\"./output/\")\n",
+" val_loader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TEST[0], evaluators=[evaluator])\n",
+" val_results = self._trainer.test(self.cfg, self.model, evaluators=[evaluator])[0]\n",
+" val_loss = val_results[\"total_loss\"]\n",
+" return val_loss\n",
+"\n",
+" def save_checkpoint(self):\n",
+" checkpointer = DetectionCheckpointer(self.model)\n",
+" checkpointer.save(self.model_checkpoint_path)\n",
+" \n",
+"\n",
+"def train_model(TRAIN,VALIDATION,folder):\n",
 " cfg = get_cfg()\n",
 " MODEL = 'faster_rcnn_X_101_32x8d_FPN_3x.yaml'\n",
 " cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/\"+MODEL))\n",
 " cfg.DATASETS.TRAIN = (TRAIN,)\n",
 " cfg.DATASETS.TEST = (VALIDATION,)\n",
 " cfg.DATALOADER.NUM_WORKERS = 2\n",
-" cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
+" #Uncomment if you want to use pre-trained weights for finetuning, not recommended for K fold training\n",
+" # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/\"+MODEL) # Let training initialize from model zoo\n",
+" \n",
+" \n",
 " cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real \"batch size\" commonly known to deep learning people\n",
-" cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR\n",
-" cfg.SOLVER.MAX_ITER = 6000 # \n",
+" cfg.SOLVER.BASE_LR = 0.004 # pick a good LR\n",
 " cfg.SOLVER.STEPS = [] # milestones where LR is reduced, in this case there's no decay\n",
 " cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The \"RoIHead batch size\". \n",
 " cfg.MODEL.ROI_HEADS.NUM_CLASSES = 80 \n",
-" cfg.TEST.EVAL_PERIOD = 500\n",
+" cfg.TEST.EVAL_PERIOD = 15000\n",
 " os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n",
-" trainer = DefaultTrainer(cfg) \n",
+" trainer = Early_stopping(cfg, early_stop_patience=5, model_checkpoint_path=\"model_checkpoint.pth\")\n",
+" # Specify evaluators during testing\n",
+" evaluator = COCOEvaluator(cfg.DATASETS.TEST[0], cfg, True, output_dir=\"./output/\")\n",
+" trainer.resume_or_load(resume=False)\n",
+" trainer.test(cfg, trainer.model, evaluators=[evaluator])\n",
 " trainer.resume_or_load(resume=False)\n",
-" trainer.train();\n"
+" trainer.train();\n",
+" return cfg\n"
 ]
 },
 {
@@ -224,7 +294,9 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "def format_detectron2_predictions(instances, num_classes):\n",
@@ -254,7 +326,7 @@
 " formatted_results = []\n",
 " for i in results:\n",
 " if len(i) == 0:\n",
-" formatted_array = np.array(i, dtype=np.float32).reshape((0, num_classes))\n",
+" formatted_array = np.array(i, dtype=np.float32).reshape((0, 5))\n",
 " else:\n",
 " formatted_array = np.array(i, dtype=np.float32)\n",
 " formatted_results.append(formatted_array)\n",
@@ -266,46 +338,51 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"scrolled": true
+"scrolled": true,
+"tags": []
 },
 "outputs": [],
 "source": [
 "for k in range(0,NUM_FOLDS):\n",
 " result_dict = {}\n",
 " train_data = pairs[k][0]\n",
 " val_data = pairs[k][1]\n",
-" train_data(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
+" cfg = train_model(train_data,val_data,\"COCO_TRAIN_\"+str(k)+\"_FOLD\")\n",
 " evaluator = COCOEvaluator(val_data, output_dir=\"output\")\n",
 " val_loader = build_detection_test_loader(cfg, val_data)\n",
 " cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, \"model_final.pth\") # path to the model we just trained\n",
 " cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1 # set a custom testing threshold\n",
 " predictor = DefaultPredictor(cfg)\n",
-" dataset = json.load(open(\"../\"+pairs[k][1]+'.json','rb'))\n",
-" for image in dat['images']:\n",
-" im_name = os.path.join(TRAIN_PATH, i['file_name'])\n",
+" dataset = json.load(open(pairs[k][1],'rb'))\n",
+" for image in dataset['images']:\n",
+" im_name = os.path.join(TRAIN_PATH, image['file_name'])\n",
 " im = cv2.imread(im_name)\n",
 " outputs = predictor(im)\n",
-" result_dict[im_name](format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
+" result_dict[im_name] = (format_detectron2_predictions(outputs[\"instances\"].to(\"cpu\"),cfg.MODEL.ROI_HEADS.NUM_CLASSES))\n",
 " pickle.dump(result_dict,open(\"results_fold_\"+str(k)+\".pkl\",'wb'))"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "result_dict = {}\n",
 "for k in range(0,NUM_FOLDS):\n",
 " res_d = pickle.load(open(\"results_fold_\"+str(k)+'.pkl','rb'))\n",
 " for r in res_d:\n",
-" result_dict[r] = res_d[i]"
+" result_dict[r] = res_d[r]"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "dataset = pickle.load(open(\"TRAIN_COCO_ALL_labels.pkl\",'rb'))\n",
@@ -333,7 +410,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.12"
+"version": "3.11.5"
 }
 },
 "nbformat": 4,
