siamese_dataset.py
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import PIL.ImageOps
def process_single_img(img_path, should_invert, transform):
    """Load an image, convert it to grayscale, optionally invert it, and apply transforms."""
    img0 = Image.open(img_path)
    # Convert the image to grayscale
    img0 = img0.convert("L")
    # Optionally invert the colors
    if should_invert:
        img0 = PIL.ImageOps.invert(img0)
    # Apply any user-supplied transforms (e.g. resize, ToTensor)
    if transform is not None:
        img0 = transform(img0)
    return img0

class SiameseNetworkDataset(Dataset):
    """Dataset of image pairs listed in a text file: '<path_0> <path_1> <label>' per line."""

    def __init__(self, sample_path, transform=None, should_invert=True):
        self.sample_path = sample_path
        # Read every line of the sample list into memory
        with open(self.sample_path, "r") as file_d:
            self.txt_context = file_d.readlines()
        self.num_ = len(self.txt_context)
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self, index):
        line = self.txt_context[index]
        # Each line holds two image paths and an integer pair label
        list_name = line.strip().split(' ')
        path_0 = list_name[0]
        path_1 = list_name[1]
        label_c = int(list_name[2])
        img0 = process_single_img(path_0, self.should_invert, self.transform)
        img1 = process_single_img(path_1, self.should_invert, self.transform)
        return img0, img1, torch.from_numpy(np.array([label_c], dtype=np.float32))

    def __len__(self):
        return self.num_
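

# A minimal usage sketch (not part of the original file): it builds the dataset from a
# hypothetical pair list "train_pairs.txt", where each line reads "img_a.png img_b.png 1",
# wraps it in a DataLoader, and pulls one batch. The file name, image size, and batch
# size below are assumptions chosen for illustration only.
if __name__ == "__main__":
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    transform = transforms.Compose([
        transforms.Resize((100, 100)),  # assumed input resolution
        transforms.ToTensor(),
    ])
    dataset = SiameseNetworkDataset("train_pairs.txt",  # hypothetical sample list
                                    transform=transform,
                                    should_invert=False)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    img0, img1, label = next(iter(loader))
    print(img0.shape, img1.shape, label.shape)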