|
49 | 49 | "import glob\n",
|
50 | 50 | "from sklearn.model_selection import KFold\n",
|
51 | 51 | "import json\n",
|
52 |
| - "from collections import defaultdict\n", |
53 |
| - "from detectron2.data.datasets import register_coco_instances" |
| 52 | + "from collections import defaultdict" |
54 | 53 | ]
|
55 | 54 | },
|
56 | 55 | {
|
|
150 | 149 | " annotations_count = len(data_dict['annotations'])\n",
|
151 | 150 | " print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
|
152 | 151 | "\n",
|
| 152 | + " \n", |
| 153 | + "def unregister_coco_instances(name):\n", |
| 154 | + " if name in DatasetCatalog.list():\n", |
| 155 | + " DatasetCatalog.remove(name)\n", |
| 156 | + " MetadataCatalog.remove(name)\n", |
| 157 | + "\n", |
153 | 158 | "# Generate K-Fold cross-validation\n",
|
154 | 159 | "kf = KFold(n_splits=NUM_FOLDS)\n",
|
155 | 160 | "pairs = []\n",
|
156 | 161 | "for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
|
157 | 162 | " train_data, test_data = split_data(train_indices, test_indices)\n",
|
158 |
| - " # Register COCO instances for training and validation. \n", |
159 |
| - " # Note: The 'train2017' folder is retained as the base path for images.\n", |
160 | 163 | " train_file = f\"train_coco_{fold}_fold.json\"\n",
|
161 | 164 | " test_file = f\"test_coco_{fold}_fold.json\"\n",
|
| 165 | + " # Unregister instances with the same names only if they exist\n", |
| 166 | + " unregister_coco_instances(train_file)\n", |
| 167 | + " unregister_coco_instances(test_file)\n", |
| 168 | + " # Register COCO instances for training and validation. \n", |
| 169 | + " # Note: The 'train2017' folder is retained as the base path for images.\n", |
162 | 170 | " register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
|
163 | 171 | " register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
|
164 | 172 | " pairs.append([train_file,test_file])\n",
|
|
215 | 223 | " \n",
|
216 | 224 | " def data_loader_mapper(self, batch):\n",
|
217 | 225 | " return batch\n",
|
218 |
| - " \n", |
219 |
| - "\n", |
220 | 226 | "\n",
|
221 | 227 | " def run_hooks(self):\n",
|
222 | 228 | " val_loss = self.validation()\n",
|
|
0 commit comments