@@ -14,10 +14,13 @@ import (
14
14
"syscall"
15
15
16
16
"encoding/json"
17
- "fmt"
18
17
"net"
19
18
"net/http"
20
19
20
+ "strings"
21
+
22
+ "fmt"
23
+
21
24
"github.com/miekg/dns"
22
25
"github.com/wrouesnel/go.log"
23
26
)
28
31
defaultServer = flag .String ("default" , "https://dns.google.com/resolve" ,
29
32
"DNS-over-HTTPS service endpoint" )
30
33
34
+ prefixServer = flag .String ("primary-dns" , "" ,
35
+ "If set all DNS queries are attempted against this DNS server first before trying HTTPS" )
36
+
37
+ suffixServer = flag .String ("fallback-dns" , "" ,
38
+ "If set all failed (i.e. NXDOMAIN and others) DNS queries are attempted against this DNS server after HTTPS fails." ) // nolint: lll
39
+
40
+ fallthroughStatuses = flag .String ("fallthrough-statuses" , "NXDOMAIN" ,
41
+ "Comma-separated list of statuses which should cause server fallthrough" )
42
+ neverDefault = flag .String ("no-fallthrough" , "" ,
43
+ "Comma-separated list of suffixes which will not be allowed to fallthrough (most useful with prefix DNS" )
44
+
31
45
//routeList = flag.String("route", "",
32
46
// "List of routes where to send queries (subdomain=IP:port)")
33
47
//routes map[string]string
@@ -157,7 +171,20 @@ func route(w dns.ResponseWriter, req *dns.Msg) {
157
171
// return
158
172
// }
159
173
//}
160
- proxy (* defaultServer , w , req )
174
+
175
+ fallthroughs := make (map [int ]struct {})
176
+ for _ , v := range strings .Split (* fallthroughStatuses , "," ) {
177
+ rcode , found := dns .StringToRcode [v ]
178
+ if ! found {
179
+ log .Fatalln ("Could not find matching Rcode integer for" , v )
180
+ }
181
+
182
+ fallthroughs [rcode ] = struct {}{}
183
+ }
184
+
185
+ noFallthrough := strings .Split (* neverDefault , "," )
186
+
187
+ proxy (* defaultServer , * prefixServer , * suffixServer , fallthroughs , noFallthrough , w , req )
161
188
}
162
189
163
190
//func isTransfer(req *dns.Msg) bool {
@@ -183,41 +210,17 @@ func route(w dns.ResponseWriter, req *dns.Msg) {
183
210
// return false
184
211
//}
185
212
186
- func proxy (addr string , w dns.ResponseWriter , req * dns.Msg ) {
187
- var err error
188
- //transport := "udp"
189
- //if _, ok := w.RemoteAddr().(*net.TCPAddr); ok {
190
- // transport = "tcp"
191
- //}
192
- //if isTransfer(req) {
193
- // if transport != "tcp" {
194
- // dns.HandleFailed(w, req)
195
- // return
196
- // }
197
- // t := new(dns.Transfer)
198
- // c, err := t.In(req, addr)
199
- // if err != nil {
200
- // dns.HandleFailed(w, req)
201
- // return
202
- // }
203
- // if err = t.Out(w, req, c); err != nil {
204
- // dns.HandleFailed(w, req)
205
- // return
206
- // }
207
- // return
208
- //}
209
- //c := &dns.Client{Net: "tcp"}
210
- //resp, _, err := c.Exchange(req, addr)
211
- //if err != nil {
212
- // dns.HandleFailed(w, req)
213
- // return
214
- //}
213
+ func dnsRequestProxy (addr string , transport string , req * dns.Msg ) (* dns.Msg , error ) {
214
+ c := & dns.Client {Net : transport }
215
+ resp , _ , err := c .Exchange (req , addr )
216
+ return resp , err
217
+ }
215
218
219
+ func httpDNSRequestProxy (addr string , _ string , req * dns.Msg ) (* dns.Msg , error ) {
216
220
httpreq , err := http .NewRequest (http .MethodGet , addr , nil )
217
221
if err != nil {
218
222
log .Errorln ("Error setting up request:" , err )
219
- dns .HandleFailed (w , req )
220
- return
223
+ return nil , err
221
224
}
222
225
223
226
qry := httpreq .URL .Query ()
@@ -233,9 +236,7 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
233
236
234
237
httpresp , err := http .DefaultClient .Do (httpreq )
235
238
if err != nil {
236
- log .Errorln ("Error sending DNS response:" , err )
237
- dns .HandleFailed (w , req )
238
- return
239
+ return nil , err
239
240
}
240
241
defer httpresp .Body .Close () // nolint: errcheck
241
242
@@ -244,9 +245,7 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
244
245
decoder := json .NewDecoder (httpresp .Body )
245
246
err = decoder .Decode (& dnsResp )
246
247
if err != nil {
247
- log .Errorln ("Malformed JSON DNS response:" , err )
248
- dns .HandleFailed (w , req )
249
- return
248
+ return nil , err
250
249
}
251
250
252
251
// Parse the google Questions to DNS RRs
@@ -298,9 +297,76 @@ func proxy(addr string, w dns.ResponseWriter, req *dns.Msg) {
298
297
Extra : extras ,
299
298
}
300
299
301
- // Write the response
302
- err = w .WriteMsg (& resp )
303
- if err != nil {
304
- log .Errorln ("Error writing DNS response:" , err )
300
+ return & resp , nil
301
+ }
302
+
303
+ func isSuccess (fallthroughStatuses map [int ]struct {}, resp * dns.Msg ) bool {
304
+ if resp == nil {
305
+ return false
306
+ }
307
+ _ , found := fallthroughStatuses [resp .Rcode ]
308
+ return ! found
309
+ }
310
+
311
+ func continueFallthrough (noFallthrough []string , req * dns.Msg ) bool {
312
+ for _ , f := range noFallthrough {
313
+ if f == "" {
314
+ continue
315
+ }
316
+ for _ , q := range req .Question {
317
+ if strings .HasSuffix (q .Name , f ) {
318
+ return false
319
+ }
320
+ }
305
321
}
322
+ return true
323
+ }
324
+
325
+ type proxyFunc func () (* dns.Msg , error )
326
+
327
+ func proxy (addr string , prefixServer string , suffixServer string , fallthroughStatuses map [int ]struct {},
328
+ noFallthrough []string , w dns.ResponseWriter , req * dns.Msg ) {
329
+
330
+ qryCanFallthrough := continueFallthrough (noFallthrough , req )
331
+
332
+ transport := "udp"
333
+ if _ , ok := w .RemoteAddr ().(* net.TCPAddr ); ok {
334
+ transport = "tcp"
335
+ }
336
+
337
+ proxyFuncs := []proxyFunc {}
338
+
339
+ // If prefix server set, try prefix server...
340
+ if prefixServer != "" {
341
+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return dnsRequestProxy (prefixServer , transport , req ) })
342
+
343
+ }
344
+
345
+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return httpDNSRequestProxy (addr , transport , req ) })
346
+
347
+ // If prefix server set, try prefix server...
348
+ if suffixServer != "" {
349
+ proxyFuncs = append (proxyFuncs , func () (* dns.Msg , error ) { return dnsRequestProxy (suffixServer , transport , req ) })
350
+
351
+ }
352
+
353
+ for _ , proxyFunc := range proxyFuncs {
354
+ resp , err := proxyFunc ()
355
+ if err == nil && (isSuccess (fallthroughStatuses , resp ) || ! qryCanFallthrough ) {
356
+ // Write the response
357
+ err = w .WriteMsg (resp )
358
+ if err != nil {
359
+ log .Errorln ("Error writing DNS response:" , err )
360
+ dns .HandleFailed (w , req )
361
+ }
362
+ return
363
+ }
364
+
365
+ if ! qryCanFallthrough {
366
+ dns .HandleFailed (w , req )
367
+ return
368
+ }
369
+ }
370
+
371
+ dns .HandleFailed (w , req )
306
372
}
0 commit comments