-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimage_preprocessing.py
33 lines (27 loc) · 1.48 KB
/
image_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#Image preprocessing functions
import numpy as np
from torchvision import datasets, transforms
from PIL import Image
# Transform pipelines for the training and testing sets, keyed by split name.
# The Normalize mean/std values are presumably the standard ImageNet channel
# statistics -- confirm against the model these tensors are fed to.

# Training: random rotation/crop/flip augmentation, then tensor + normalize.
_train_steps = [
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

# Testing: deterministic resize + center crop, then tensor + normalize.
_test_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'test': transforms.Compose(_test_steps),
}
# Convert the image to a NumPy array
def process_image(image):
    """Load an image file and return it as a preprocessed NumPy array.

    The image is resized to 256, center-cropped to 224, converted to a
    tensor, and normalized per channel -- the same steps as the 'test'
    pipeline in ``data_transforms``.

    Args:
        image: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        numpy.ndarray holding the normalized image; with the RGB conversion
        and 224 center crop below this is channel-first, (3, 224, 224).

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file
        (e.g. ``FileNotFoundError``); no errors are swallowed here.
    """
    pre_img = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    # Use a context manager so the underlying file handle is closed as soon
    # as the pixels have been read, instead of leaking until GC.
    with Image.open(image) as img:
        # Normalize is configured with three per-channel values, so force a
        # 3-channel image; grayscale, palette, and RGBA inputs would
        # otherwise produce a channel-count mismatch.
        tensor = pre_img(img.convert('RGB'))

    np_image = np.array(tensor)
    return np_image