-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
111 lines (91 loc) · 3.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet50
from dataset.coco import COCODataset
from trainer.train_and_validate import train_and_validate
from predictor.predict_and_test import predict_and_test
from transforms.transforms import *
from utils import collate_fn
def main():
    """Fine-tune a pretrained ResNet-50 on a 5-class COCO subset.

    Builds one COCODataset per target category, splits the pooled samples
    into train/val/test (80/10/10), trains with cross-entropy + Adam, and
    logs metrics to TensorBoard. No parameters; no return value.
    """
    root = "./my_coco_subset"  # Path to the COCO subset dataset
    targets = [
        "bird",
        "cat",
        "dog",
        "horse",
        "sheep",
    ]  # The categories to be classified
    # Map each category name to its numeric class index (0..4).
    target_map = {
        target: idx for idx, target in enumerate(targets)
    }

    def _build_pool(transform):
        """Return a ConcatDataset over all categories using *transform*.

        NOTE(review): assumes COCODataset reads ``self.target`` as the
        numeric label, and that construction is deterministic so pools
        built with different transforms index identical samples — confirm.
        """
        parts = [COCODataset(root, target, transform=transform) for target in targets]
        for part, target in zip(parts, targets):
            part.target = target_map[target]
        return ConcatDataset(parts)

    # Split sizes: 80% train, 10% val, remainder test (avoids rounding loss).
    data_len = len(_build_pool(initial_transform()))
    train_size = int(0.8 * data_len)
    val_size = int(0.1 * data_len)
    test_size = data_len - train_size - val_size

    # BUG FIX: the original assigned `.transform` on the Subset objects
    # returned by random_split — a silent no-op, since Subset has no
    # transform hook and delegates __getitem__ to the wrapped dataset, so
    # every split kept initial_transform(). Instead, shuffle indices once
    # and wrap each index set around a pool built with the correct
    # per-split transform.
    indices = torch.randperm(data_len).tolist()
    train_idx = indices[:train_size]
    val_idx = indices[train_size : train_size + val_size]
    test_idx = indices[train_size + val_size :]
    Subset = torch.utils.data.Subset
    train_set = Subset(_build_pool(train_transform()), train_idx)
    val_set = Subset(_build_pool(val_transform()), val_idx)
    test_set = Subset(_build_pool(test_transform()), test_idx)

    # Create dataloaders (shuffle only the training split).
    train_loader = DataLoader(
        train_set, batch_size=32, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_set, batch_size=32, shuffle=False, collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_set, batch_size=32, shuffle=False, collate_fn=collate_fn
    )

    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pretrained ResNet-50; replace the classifier head for our 5 classes.
    model = resnet50(weights="DEFAULT")
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        # Dropout regularizes the small fine-tuning head.
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, len(targets)),
    )
    model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    learning_rate = 1e-5
    # Weight decay adds L2 regularization on top of the low learning rate.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-3
    )
    # NOTE(review): StepLR(step_size=5, gamma=0.1) shrinks the LR by 10x
    # every 5 epochs — by epoch ~50 of 350 the LR is numerically zero, so
    # most epochs do no learning. Confirm whether a milder schedule (or
    # fewer epochs) is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    epochs = 350
    writer = SummaryWriter("logs")

    # Train and cross-validate
    train_and_validate(
        model,
        device,
        train_loader,
        val_loader,
        loss_fn,
        optimizer,
        scheduler,
        epochs,
        writer,
    )
    # Predict and test
    predict_and_test(model, test_loader, loss_fn, epochs, device, writer)
    writer.close()
# Run the full training pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()