1
+ import numpy as np
2
+ from numpy .linalg import inv
3
+ import pyPyrTools as ppt
4
+ from pyPyrTools .corrDn import corrDn
5
+ import math
6
+
7
+ def vifvec (imref_batch ,imdist_batch ):
8
+ M = 3
9
+ subbands = [4 , 7 , 10 , 13 , 16 , 19 , 22 , 25 ]
10
+ sigma_nsq = 0.4
11
+
12
+ batch_num = 1
13
+ if imref_batch .ndim >= 3 :
14
+ batch_num = imref_batch .shape [0 ]
15
+
16
+ vif = np .zeros ([batch_num ,])
17
+
18
+ for a in range (batch_num ):
19
+ if batch_num > 1 :
20
+ imref = imref_batch [a ,:,:]
21
+ imdist = imdist_batch [a ,:,:]
22
+ else :
23
+ imref = imref_batch
24
+ imdist = imdist_batch
25
+
26
+ #Wavelet Decomposition
27
+ pyr = ppt .Spyr (imref , 4 , 'sp5Filters' , 'reflect1' )
28
+ org = pyr .pyr [::- 1 ] #reverse list
29
+
30
+ pyr = ppt .Spyr (imdist , 4 , 'sp5Filters' , 'reflect1' )
31
+ dist = pyr .pyr [::- 1 ]
32
+
33
+ #Calculate parameters of the distortion channel
34
+ g_all , vv_all = vif_sub_est_M (org , dist , subbands , M )
35
+
36
+ #calculate the parameters of reference
37
+ ssarr , larr , cuarr = refparams_vecgsm (org , subbands , M )
38
+
39
+ num = np .zeros ([1 ,len (subbands )])
40
+ den = np .zeros ([1 ,len (subbands )])
41
+
42
+ for i in range (len (subbands )):
43
+ sub = subbands [i ]
44
+ g = g_all [i ]
45
+ vv = vv_all [i ]
46
+ ss = ssarr [i ]
47
+ lam = larr [i ]
48
+ #cu = cuarr[i]
49
+
50
+ #neigvals = len(lam)
51
+ lev = math .ceil ((sub - 1 )/ 6 )
52
+ winsize = 2 ** lev + 1
53
+ offset = (winsize - 1 )/ 2
54
+ offset = math .ceil (offset / M )
55
+
56
+ g = g [offset :g .shape [0 ]- offset ,offset :g .shape [1 ]- offset ]
57
+ vv = vv [offset :vv .shape [0 ]- offset ,offset :vv .shape [1 ]- offset ]
58
+ ss = ss [offset :ss .shape [0 ]- offset ,offset :ss .shape [1 ]- offset ]
59
+
60
+ temp1 ,temp2 = 0 ,0
61
+ rt = []
62
+ for j in range (len (lam )):
63
+ temp1 += np .sum (np .log2 (1 + np .divide (np .multiply (np .multiply (g ,g ),ss ) * lam [j ], vv + sigma_nsq ))) #distorted image information
64
+ temp2 += np .sum (np .log2 (1 + np .divide (ss * lam [j ], sigma_nsq ))) #reference image information
65
+ rt .append (np .sum (np .log (1 + np .divide (ss * lam [j ], sigma_nsq ))))
66
+
67
+ num [0 ,i ] = temp1
68
+ den [0 ,i ] = temp2
69
+
70
+ vif [a ] = np .sum (num )/ np .sum (den )
71
+ print (vif )
72
+ return vif
73
+
74
+
75
+ def vif_sub_est_M (org , dist , subbands , M ):
76
+ tol = 1e-15 #tolerance for zero variance
77
+ g_all = []
78
+ vv_all = []
79
+
80
+ for i in range (len (subbands )):
81
+ sub = subbands [i ]
82
+ y = org [sub - 1 ]
83
+ yn = dist [sub - 1 ]
84
+
85
+ #size of window used in distortion channel estimation
86
+ lev = math .ceil ((sub - 1 )/ 6 )
87
+ winsize = 2 ** lev + 1
88
+ win = np .ones ([winsize , winsize ])
89
+
90
+ #force subband to be a multiple of M
91
+ newsize = [math .floor (y .shape [0 ]/ M ) * M , math .floor (y .shape [1 ]/ M ) * M ]
92
+ y = y [:newsize [0 ],:newsize [1 ]]
93
+ yn = yn [:newsize [0 ],:newsize [1 ]]
94
+
95
+ #correlation with downsampling
96
+ winstep = (M , M )
97
+ winstart = (math .floor (M / 2 ) ,math .floor (M / 2 ))
98
+ winstop = (y .shape [0 ] - math .ceil (M / 2 ) + 1 , y .shape [1 ] - math .ceil (M / 2 ) + 1 )
99
+
100
+ #mean
101
+ mean_x = corrDn (y , win / np .sum (win ), 'reflect1' , winstep , winstart , winstop )
102
+ mean_y = corrDn (yn , win / np .sum (win ), 'reflect1' , winstep , winstart , winstop )
103
+
104
+ #covariance
105
+ cov_xy = corrDn (np .multiply (y , yn ), win , 'reflect1' , winstep , winstart , winstop ) - \
106
+ np .sum (win ) * np .multiply (mean_x ,mean_y )
107
+
108
+ #variance
109
+ ss_x = corrDn (np .multiply (y ,y ), win , 'reflect1' , winstep , winstart , winstop ) - np .sum (win ) * np .multiply (mean_x ,mean_x )
110
+ ss_y = corrDn (np .multiply (yn ,yn ), win , 'reflect1' , winstep , winstart , winstop ) - np .sum (win ) * np .multiply (mean_y , mean_y )
111
+
112
+ ss_x [np .where (ss_x < 0 )] = 0
113
+ ss_y [np .where (ss_y < 0 )] = 0
114
+
115
+ #Regression
116
+ g = np .divide (cov_xy ,(ss_x + tol ))
117
+
118
+ vv = (ss_y - np .multiply (g , cov_xy ))/ (np .sum (win ))
119
+
120
+ g [np .where (ss_x < tol )] = 0
121
+ vv [np .where (ss_x < tol )] = ss_y [np .where (ss_x < tol )]
122
+ ss_x [np .where (ss_x < tol )] = 0
123
+
124
+ g [np .where (ss_y < tol )] = 0
125
+ vv [np .where (ss_y < tol )] = 0
126
+
127
+ g [np .where (g < 0 )] = 0
128
+ vv [np .where (g < 0 )] = ss_y [np .where (g < 0 )]
129
+
130
+ vv [np .where (vv <= tol )] = tol
131
+
132
+ g_all .append (g )
133
+ vv_all .append (vv )
134
+
135
+ return g_all , vv_all
136
+
137
+ def refparams_vecgsm (org , subbands , M ):
138
+ # This function caluclates the parameters of the reference image
139
+ #l_arr = np.zeros([subbands[-1],M**2])
140
+ l_arr , ssarr , cu_arr = [],[],[]
141
+ for i in range (len (subbands )):
142
+ sub = subbands [i ]
143
+ y = org [sub - 1 ]
144
+
145
+ sizey = (math .floor (y .shape [0 ]/ M )* M , math .floor (y .shape [1 ]/ M )* M )
146
+ y = y [:sizey [0 ],:sizey [1 ]]
147
+
148
+ #Collect MxM blocks, rearrange into M^2 dimensional vector
149
+ temp = []
150
+ for j in range (M ):
151
+ for k in range (M ):
152
+ temp .append (y [k :y .shape [0 ]- M + k + 1 ,j :y .shape [1 ]- M + j + 1 ].T .reshape (- 1 ))
153
+
154
+ temp = np .asarray (temp )
155
+ mcu = np .mean (temp , axis = 1 ).reshape (temp .shape [0 ],1 )
156
+ mean_sub = temp - np .repeat (mcu ,temp .shape [1 ],axis = 1 )
157
+ cu = mean_sub @ mean_sub .T / temp .shape [1 ]
158
+ #Calculate S field, non-overlapping blocks
159
+ temp = []
160
+ for j in range (M ):
161
+ for k in range (M ):
162
+ temp .append (y [k ::M ,j ::M ].T .reshape (- 1 ))
163
+
164
+ temp = np .asarray (temp )
165
+ ss = inv (cu ) @ temp
166
+ ss = np .sum (np .multiply (ss ,temp ),axis = 0 )/ (M ** 2 )
167
+ ss = ss .reshape (int (sizey [1 ]/ M ), int (sizey [0 ]/ M )).T
168
+
169
+ d , _ = np .linalg .eig (cu )
170
+ l_arr .append (d )
171
+ ssarr .append (ss )
172
+ cu_arr .append (cu )
173
+
174
+ return ssarr , l_arr , cu_arr
0 commit comments