Skip to content

Commit d53d75b

Browse files
committed
Added code support for adjusted Entropy and MSP scoring.
1 parent abb6ff2 commit d53d75b

File tree

3 files changed

+721
-0
lines changed

3 files changed

+721
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "d61e5d70-45e1-4223-b569-7a4c9247876d",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
# Auto-reload edited local modules so code changes are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Make the repository root (one level up) importable from this notebook.
import sys
sys.path.insert(0, "../")


from autogluon.vision import ImagePredictor, ImageDataset
import numpy as np
import pandas as pd
import pickle
import datetime
from pathlib import Path
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Display DataFrames without truncating rows, columns, or cell contents.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"id": "bc2ebf60-4338-45ce-b9ce-e0d2b5cc7f0d",
36+
"metadata": {},
37+
"source": [
38+
"## Read data"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"id": "5c9b59b4-c51c-4cdb-a958-46f227cdb5d8",
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
# Root folders for each image dataset, stored as PNGs.
# NOTE(review): absolute paths — assumes the /datasets/uly mount is present
# on this machine; confirm before running elsewhere.
CIFAR_10_DATA_PATH = "/datasets/uly/ood-data/cifar10_png/"
CIFAR_100_DATA_PATH = "/datasets/uly/ood-data/cifar100_png/"
MNIST_DATA_PATH = "/datasets/uly/ood-data/mnist_png/"
FASHION_MNIST_DATA_PATH = "/datasets/uly/ood-data/fashion_mnist_png/"

# ImageDataset.from_folders returns a 3-tuple; the middle element
# (presumably a validation split) is discarded throughout.
cifar_10_train_dataset, _, cifar_10_test_dataset = ImageDataset.from_folders(root=CIFAR_10_DATA_PATH)
cifar_100_train_dataset, _, cifar_100_test_dataset = ImageDataset.from_folders(root=CIFAR_100_DATA_PATH)
mnist_train_dataset, _, mnist_test_dataset = ImageDataset.from_folders(root=MNIST_DATA_PATH)
fashion_mnist_train_dataset, _, fashion_mnist_test_dataset = ImageDataset.from_folders(root=FASHION_MNIST_DATA_PATH)
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"id": "cde63994-e833-4f87-93b6-e05b3c7ba479",
65+
"metadata": {},
66+
"outputs": [],
67+
"source": [
68+
"# dictionary to store data path and model\n",
69+
"\n",
70+
"data_model_dict = {\n",
71+
" \"cifar-10\": {\n",
72+
" \"train_data\": cifar_10_train_dataset,\n",
73+
" \"test_data\": cifar_10_test_dataset,\n",
74+
" },\n",
75+
" \"cifar-100\": {\n",
76+
" \"train_data\": cifar_100_train_dataset,\n",
77+
" \"test_data\": cifar_100_test_dataset,\n",
78+
" },\n",
79+
" \"mnist\": {\n",
80+
" \"train_data\": mnist_train_dataset,\n",
81+
" \"test_data\": mnist_test_dataset,\n",
82+
" },\n",
83+
" \"fashion-mnist\": {\n",
84+
" \"train_data\": fashion_mnist_train_dataset,\n",
85+
" \"test_data\": fashion_mnist_test_dataset,\n",
86+
" },\n",
87+
"}"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "8606e688",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
# Create mini train dataset for testing
def get_imbalanced_dataset(dataset, fractions, random_state=None):
    """Build a class-imbalanced subset of ``dataset``.

    Parameters
    ----------
    dataset : pd.DataFrame
        Must have a 'label' column; assumes labels are the integers
        0..len(fractions)-1 — TODO confirm for every dataset used.
    fractions : sequence of float
        Fraction of rows to keep for each label (fractions[i] applies to
        label i).
    random_state : optional
        Seed forwarded to ``DataFrame.sample`` so the subset is
        reproducible. Defaults to None (non-deterministic), matching the
        previous behavior.

    Returns
    -------
    pd.DataFrame
        Concatenation of the per-label samples with a fresh RangeIndex.
    """
    assert len(fractions) == dataset['label'].nunique()

    # Collect per-label samples and concatenate once at the end: avoids the
    # quadratic cost of growing a DataFrame inside the loop, and avoids the
    # object-dtype coercion caused by concatenating onto an empty frame.
    sampled_parts = []
    for i in range(len(fractions)):
        idf = dataset[dataset['label'] == i].sample(frac=fractions[i], random_state=random_state)
        print(f'label {i} will have {idf.shape[0]} examples')
        sampled_parts.append(idf)

    imbalanced_dataset = pd.concat(sampled_parts, ignore_index=True)
    print(f'total imbalanced dataset length {imbalanced_dataset.shape[0]}')
    return imbalanced_dataset
109+
"\n",
110+
"### Uncomment below to create imbalanced datasets\n",
111+
"\n",
112+
"# cifar_100_num_classes = len(cifar_100_train_dataset['label'].unique())\n",
113+
"# cifar_100_distribution = [0.15] * int(cifar_100_num_classes * 0.9) + [1.] * int(cifar_100_num_classes * 0.1)\n",
114+
"# cifar_100_train_dataset = get_imbalanced_dataset(cifar_100_train_dataset, cifar_100_distribution)\n",
115+
"# cifar_10_train_dataset = get_imbalanced_dataset(cifar_10_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])\n",
116+
"# mnist_train_dataset = get_imbalanced_dataset(mnist_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])\n",
117+
"# fashion_mnist_train_dataset = get_imbalanced_dataset(fashion_mnist_train_dataset,[0.09,0.09,0.09,0.09,1.,1.,0.09,0.09,1.,1.])"
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": null,
123+
"id": "1ae79a8d-bb68-46d5-b4b9-1f082da7d695",
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
# Sanity check: display the first rows of one dataset (rich display via the
# bare last expression — no print needed).
mnist_train_dataset.head()
129+
]
130+
},
131+
{
132+
"cell_type": "markdown",
133+
"id": "cc26ea6d-954c-4810-a561-50badcdd992d",
134+
"metadata": {},
135+
"source": [
136+
"## Train model"
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": null,
142+
"id": "abfa0bb0-aa32-47ac-a453-9ac5a2d91c96",
143+
"metadata": {},
144+
"outputs": [],
145+
"source": [
146+
"%%time\n",
147+
"\n",
148+
def train_ag_model(
    train_data,
    dataset_name,
    model_folder="./models/", 
    epochs=100,
    model="swin_base_patch4_window7_224",
    time_limit=10*3600
):
    """Train an AutoGluon image classifier and persist it to disk.

    Parameters
    ----------
    train_data : ImageDataset
        Training split to fit on.
    dataset_name : str
        Used in the saved model's filename.
    model_folder : str
        Directory for saved models; created if it does not exist.
    epochs : int
        Training epochs passed to the fit hyperparameters.
    model : str
        Backbone architecture name.
    time_limit : int
        Wall-clock budget for fit, in seconds.

    Returns
    -------
    ImagePredictor
        The fitted predictor (also saved to ``model_folder``).
    """
    # init model (verbosity=0 keeps the long training log out of the notebook)
    predictor = ImagePredictor(verbosity=0)

    MODEL_PARAMS = {
        "model": model,
        "epochs": epochs,
    }

    # run training with a fixed seed so repeated runs are comparable
    predictor.fit(
        train_data=train_data,
        ngpus_per_trial=1,
        hyperparameters=MODEL_PARAMS,
        time_limit=time_limit,
        random_state=123,
    )

    # Save the model. Path-join instead of string concatenation so this works
    # whether or not model_folder ends with a slash; create the folder first
    # so save() cannot fail on a fresh checkout.
    save_dir = Path(model_folder)
    save_dir.mkdir(parents=True, exist_ok=True)
    predictor.save(str(save_dir / f"{model}_{dataset_name}.ag"))

    return predictor
180+
]
181+
},
182+
{
183+
"cell_type": "markdown",
184+
"id": "a2a4cfa4-f028-4236-a15d-e3d6e7df9f20",
185+
"metadata": {},
186+
"source": [
187+
"## Train model for all datasets"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"id": "8bd6e11c-6856-4a4d-80b7-01b5635e5ffb",
194+
"metadata": {},
195+
"outputs": [],
196+
"source": [
197+
# Fit one model per dataset using the shared training helper.
model = "swin_base_patch4_window7_224"

for dataset_name, splits in data_model_dict.items():
    train_dataset = splits["train_data"]

    print(f"Dataset: {dataset_name}")
    print(f" Records: {train_dataset.shape}")
    print(f" Classes: {train_dataset.label.nunique()}")

    # Return value intentionally discarded; the helper saves the model itself.
    _ = train_ag_model(train_dataset, dataset_name=dataset_name, model=model, epochs=100)
209+
]
210+
}
211+
],
212+
"metadata": {
213+
"kernelspec": {
214+
"display_name": "Python 3 (ipykernel)",
215+
"language": "python",
216+
"name": "python3"
217+
},
218+
"language_info": {
219+
"codemirror_mode": {
220+
"name": "ipython",
221+
"version": 3
222+
},
223+
"file_extension": ".py",
224+
"mimetype": "text/x-python",
225+
"name": "python",
226+
"nbconvert_exporter": "python",
227+
"pygments_lexer": "ipython3",
228+
"version": "3.8.10"
229+
}
230+
},
231+
"nbformat": 4,
232+
"nbformat_minor": 5
233+
}

0 commit comments

Comments
 (0)