-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
109 lines (80 loc) · 2.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/python
import torch
def transform(input, randlist_all):
    """Flip and rotate each size-1 chunk of ``input`` as selected by ``randlist_all``.

    ``input`` is cut into chunks of size 1 along dimension 0. Chunk ``i`` is
    driven by ``randlist_all[i]``: each of its first five entries is compared
    with a threshold of ``10 * 0.5`` and, when strictly below it, enables the
    corresponding step. Steps are applied in this fixed order:

    0. flip along dimension 2
    1. flip along dimension 3
    2. one quarter turn in the (2, 3) plane
    3. three quarter turns in the (2, 3) plane
    4. two quarter turns in the (2, 3) plane

    Args:
        input: Tensor whose dimensions 2 and 3 are flipped/rotated, so it
            needs at least four dimensions.
        randlist_all: Indexable with one entry per chunk; every entry must
            support indexing at positions 0..4 and ``<`` against a float.
            NOTE(review): the factor 10 suggests draws from 0..9 -- confirm
            against the caller.

    Returns:
        The processed chunks concatenated along dimension 0.
    """
    flip_limit = 10 * 0.5
    rot_limit = 10 * 0.5

    # One (threshold, operation) pair per position in a chunk's draw list;
    # the tuple order is the order in which the operations are applied.
    steps = (
        (flip_limit, lambda t: torch.flip(t, [2])),
        (flip_limit, lambda t: torch.flip(t, [3])),
        (rot_limit, lambda t: torch.rot90(t, 1, [2, 3])),
        (rot_limit, lambda t: torch.rot90(t, 3, [2, 3])),
        (rot_limit, lambda t: torch.rot90(t, 2, [2, 3])),
    )

    chunks = list(torch.split(input, 1))
    for idx, chunk in enumerate(chunks):
        draws = randlist_all[idx]
        for slot, (limit, op) in enumerate(steps):
            if draws[slot] < limit:
                chunk = op(chunk)
        chunks[idx] = chunk
    return torch.cat(chunks)
def backtransform(input, randlist_all):
    """Undo ``transform`` for each size-1 chunk of ``input``.

    ``transform`` applies, per chunk and in this order, a flip along
    dimension 2, a flip along dimension 3, and rotations by one, three and
    two quarter turns in the (2, 3) plane; each step is enabled when the
    matching entry of ``randlist_all[i]`` is strictly below ``10 * 0.5``.

    This function reverts those steps. Two things are required for the
    result to be the exact inverse:

    * the same threshold as ``transform`` is used for every step, so a step
      is undone if and only if it was applied;
    * the steps are undone in reverse order (rotations first, flips last),
      because a single-axis flip does not commute with a quarter turn.

    Args:
        input: Tensor previously produced by ``transform`` (dimensions 2 and
            3 are rotated/flipped, so at least four dimensions are needed).
        randlist_all: The same per-chunk draw lists that were passed to
            ``transform``; entry ``i`` must support indexing at 0..4.

    Returns:
        The restored chunks concatenated along dimension 0.
    """
    # Must stay identical to the thresholds used by ``transform``.
    threshold = 10 * 0.5

    restored = []
    for idx, chunk in enumerate(torch.split(input, 1)):
        draws = randlist_all[idx]
        # Reverse order of the forward pass: last applied, first undone.
        if draws[4] < threshold:
            # Two quarter turns are their own inverse.
            chunk = torch.rot90(chunk, 2, [2, 3])
        if draws[3] < threshold:
            # Forward used three quarter turns; one more completes the circle.
            chunk = torch.rot90(chunk, 1, [2, 3])
        if draws[2] < threshold:
            # Forward used one quarter turn; three more complete the circle.
            chunk = torch.rot90(chunk, 3, [2, 3])
        if draws[1] < threshold:
            chunk = torch.flip(chunk, [3])
        if draws[0] < threshold:
            chunk = torch.flip(chunk, [2])
        restored.append(chunk)
    return torch.cat(restored)
def get_colours():
    """Return the fixed palette as a list of ``[r, g, b]`` triples (0-255).

    A new outer list and new inner lists are built on every call, so callers
    may modify the result without affecting later calls.
    """
    return [
        [255, 0, 0],      # red
        [0, 255, 0],      # green
        [0, 0, 255],      # blue
        [255, 255, 0],    # yellow
        [0, 0, 0],        # black
        [255, 255, 255],  # white
        [0, 255, 255],    # cyan
        [255, 128, 0],    # orange
        [255, 0, 255],    # pink
        [102, 0, 204],    # violet
        [0, 102, 0],      # dark green
    ]