Skip to content

Commit 4a2e1da

Browse files
committed
unregister if already registered
1 parent 43794dc commit 4a2e1da

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

object_detection/detectron2_training-kfold.ipynb

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@
4949
"import glob\n",
5050
"from sklearn.model_selection import KFold\n",
5151
"import json\n",
52-
"from collections import defaultdict\n",
53-
"from detectron2.data.datasets import register_coco_instances"
52+
"from collections import defaultdict"
5453
]
5554
},
5655
{
@@ -150,15 +149,24 @@
150149
" annotations_count = len(data_dict['annotations'])\n",
151150
" print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
152151
"\n",
152+
" \n",
153+
"def unregister_coco_instances(name):\n",
154+
" if name in DatasetCatalog.list():\n",
155+
" DatasetCatalog.remove(name)\n",
156+
" MetadataCatalog.remove(name)\n",
157+
"\n",
153158
"# Generate K-Fold cross-validation\n",
154159
"kf = KFold(n_splits=NUM_FOLDS)\n",
155160
"pairs = []\n",
156161
"for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
157162
" train_data, test_data = split_data(train_indices, test_indices)\n",
158-
" # Register COCO instances for training and validation. \n",
159-
" # Note: The 'train2017' folder is retained as the base path for images.\n",
160163
" train_file = f\"train_coco_{fold}_fold.json\"\n",
161164
" test_file = f\"test_coco_{fold}_fold.json\"\n",
165+
" # Unregister instances with the same names only if they exist\n",
166+
" unregister_coco_instances(train_file)\n",
167+
" unregister_coco_instances(test_file)\n",
168+
" # Register COCO instances for training and validation. \n",
169+
" # Note: The 'train2017' folder is retained as the base path for images.\n",
162170
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
163171
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
164172
" pairs.append([train_file,test_file])\n",
@@ -215,8 +223,6 @@
215223
" \n",
216224
" def data_loader_mapper(self, batch):\n",
217225
" return batch\n",
218-
" \n",
219-
"\n",
220226
"\n",
221227
" def run_hooks(self):\n",
222228
" val_loss = self.validation()\n",

0 commit comments

Comments
 (0)