@@ -27,6 +27,7 @@ import (
27
27
ss "github.com/Jigsaw-Code/outline-ss-server/shadowsocks"
28
28
logging "github.com/op/go-logging"
29
29
"github.com/shadowsocks/go-shadowsocks2/socks"
30
+ "github.com/stretchr/testify/assert"
30
31
)
31
32
32
33
const timeout = 5 * time .Minute
@@ -89,10 +90,16 @@ func (conn *fakePacketConn) Close() error {
89
90
return nil
90
91
}
91
92
93
+ type udpReport struct {
94
+ clientLocation , accessKey , status string
95
+ clientProxyBytes , proxyTargetBytes int
96
+ }
97
+
92
98
// Stub metrics implementation for testing NAT behaviors.
93
99
type natTestMetrics struct {
94
100
metrics.ShadowsocksMetrics
95
101
natEntriesAdded int
102
+ upstreamPackets []udpReport
96
103
}
97
104
98
105
func (m * natTestMetrics ) AddTCPProbe (clientLocation , status , drainResult string , port int , data metrics.ProxyMetrics ) {
@@ -107,6 +114,7 @@ func (m *natTestMetrics) SetNumAccessKeys(numKeys int, numPorts int) {
107
114
func (m * natTestMetrics ) AddOpenTCPConnection (clientLocation string ) {
108
115
}
109
116
func (m * natTestMetrics ) AddUDPPacketFromClient (clientLocation , accessKey , status string , clientProxyBytes , proxyTargetBytes int , timeToCipher time.Duration ) {
117
+ m .upstreamPackets = append (m .upstreamPackets , udpReport {clientLocation , accessKey , status , clientProxyBytes , proxyTargetBytes })
110
118
}
111
119
func (m * natTestMetrics ) AddUDPPacketFromTarget (clientLocation , accessKey , status string , targetProxyBytes , proxyClientBytes int ) {
112
120
}
@@ -115,21 +123,20 @@ func (m *natTestMetrics) AddUDPNatEntry() {
115
123
}
116
124
func (m * natTestMetrics ) RemoveUDPNatEntry () {}
117
125
118
- func TestIPFilter (t * testing.T ) {
119
- // Takes a validation policy, and returns the metrics it
120
- // generates when localhost access is attempted
121
- checkLocalhost := func (validator onet.TargetIPValidator ) * natTestMetrics {
122
- ciphers , _ := MakeTestCiphers ([]string {"asdf" })
123
- cipher := ciphers .SnapshotForClientIP (nil )[0 ].Value .(* CipherEntry ).Cipher
124
- clientConn := makePacketConn ()
125
- metrics := & natTestMetrics {}
126
- service := NewUDPService (timeout , ciphers , metrics )
127
- service .SetTargetIPValidator (validator )
128
- go service .Serve (clientConn )
129
-
130
- // Send one packet to the "discard" port on localhost
131
- targetAddr := socks .ParseAddr ("127.0.0.1:9" )
132
- payload := []byte ("payload" )
126
+ // Takes a validation policy, and returns the metrics it
127
+ // generates when localhost access is attempted
128
+ func sendToDiscard (payloads [][]byte , validator onet.TargetIPValidator ) * natTestMetrics {
129
+ ciphers , _ := MakeTestCiphers ([]string {"asdf" })
130
+ cipher := ciphers .SnapshotForClientIP (nil )[0 ].Value .(* CipherEntry ).Cipher
131
+ clientConn := makePacketConn ()
132
+ metrics := & natTestMetrics {}
133
+ service := NewUDPService (timeout , ciphers , metrics )
134
+ service .SetTargetIPValidator (validator )
135
+ go service .Serve (clientConn )
136
+
137
+ // Send one packet to the "discard" port on localhost
138
+ targetAddr := socks .ParseAddr ("127.0.0.1:9" )
139
+ for _ , payload := range payloads {
133
140
plaintext := append (targetAddr , payload ... )
134
141
ciphertext := make ([]byte , cipher .SaltSize ()+ len (plaintext )+ cipher .TagSize ())
135
142
ss .Pack (ciphertext , plaintext , cipher )
@@ -140,26 +147,52 @@ func TestIPFilter(t *testing.T) {
140
147
},
141
148
payload : ciphertext ,
142
149
}
143
-
144
- service .GracefulStop ()
145
- return metrics
146
150
}
147
151
152
+ service .GracefulStop ()
153
+ return metrics
154
+ }
155
+
156
+ func TestIPFilter (t * testing.T ) {
157
+ // Test both the first-packet and subsequent-packet cases.
158
+ payloads := [][]byte {[]byte ("payload1" ), []byte ("payload2" )}
159
+
148
160
t .Run ("Localhost allowed" , func (t * testing.T ) {
149
- metrics := checkLocalhost (allowAll )
150
- if metrics .natEntriesAdded != 1 {
151
- t .Errorf ("Expected 1 NAT entry, not %d" , metrics .natEntriesAdded )
152
- }
161
+ metrics := sendToDiscard (payloads , allowAll )
162
+ assert .Equal (t , metrics .natEntriesAdded , 1 , "Expected 1 NAT entry, not %d" , metrics .natEntriesAdded )
153
163
})
154
164
155
165
t .Run ("Localhost not allowed" , func (t * testing.T ) {
156
- metrics := checkLocalhost (onet .RequirePublicIP )
157
- if metrics .natEntriesAdded != 0 {
158
- t .Error ("Unexpected NAT entry on rejected packet" )
166
+ metrics := sendToDiscard (payloads , onet .RequirePublicIP )
167
+ assert .Equal (t , 0 , metrics .natEntriesAdded , "Unexpected NAT entry on rejected packet" )
168
+ assert .Equal (t , 2 , len (metrics .upstreamPackets ), "Expected 2 reports, not %v" , metrics .upstreamPackets )
169
+ for _ , report := range metrics .upstreamPackets {
170
+ assert .Greater (t , report .clientProxyBytes , 0 , "Expected nonzero input packet size" )
171
+ assert .Equal (t , 0 , report .proxyTargetBytes , "No bytes should be sent due to a disallowed packet" )
172
+ assert .Equal (t , report .accessKey , "id-0" , "Unexpected access key: %s" , report .accessKey )
159
173
}
160
174
})
161
175
}
162
176
177
+ func TestUpstreamMetrics (t * testing.T ) {
178
+ // Test both the first-packet and subsequent-packet cases.
179
+ const N = 10
180
+ payloads := make ([][]byte , 0 )
181
+ for i := 1 ; i <= N ; i ++ {
182
+ payloads = append (payloads , make ([]byte , i ))
183
+ }
184
+
185
+ metrics := sendToDiscard (payloads , allowAll )
186
+
187
+ assert .Equal (t , N , len (metrics .upstreamPackets ), "Expected %d reports, not %v" , N , metrics .upstreamPackets )
188
+ for i , report := range metrics .upstreamPackets {
189
+ assert .Equal (t , i + 1 , report .proxyTargetBytes , "Expected %d payload bytes, not %d" , i + 1 , report .proxyTargetBytes )
190
+ assert .Greater (t , report .clientProxyBytes , report .proxyTargetBytes , "Expected nonzero input overhead (%d > %d)" , report .clientProxyBytes , report .proxyTargetBytes )
191
+ assert .Equal (t , "id-0" , report .accessKey , "Unexpected access key name: %s" , report .accessKey )
192
+ assert .Equal (t , "OK" , report .status , "Wrong status: %s" , report .status )
193
+ }
194
+ }
195
+
163
196
func assertAlmostEqual (t * testing.T , a , b time.Time ) {
164
197
delta := a .Sub (b )
165
198
limit := 100 * time .Millisecond
0 commit comments