|
1 | 1 | package dns
|
2 | 2 |
|
3 |
| -import "github.com/miekg/dns" |
| 3 | +import ( |
| 4 | + "github.com/sagernet/sing/common/buf" |
4 | 5 |
|
5 |
| -func TruncateDNSMessage(request *dns.Msg, response *dns.Msg) (*dns.Msg, int) { |
| 6 | + "github.com/miekg/dns" |
| 7 | +) |
| 8 | + |
| 9 | +func TruncateDNSMessage(request *dns.Msg, response *dns.Msg, headroom int) (*buf.Buffer, error) { |
6 | 10 | maxLen := 512
|
7 | 11 | if edns0Option := request.IsEdns0(); edns0Option != nil {
|
8 |
| - if udpSize := int(edns0Option.UDPSize()); udpSize > 0 { |
| 12 | + if udpSize := int(edns0Option.UDPSize()); udpSize > 512 { |
9 | 13 | maxLen = udpSize
|
10 | 14 | }
|
11 | 15 | }
|
12 |
| - return truncateDNSMessage(response, maxLen) |
13 |
| -} |
14 |
| - |
15 |
| -func truncateDNSMessage(response *dns.Msg, maxLen int) (*dns.Msg, int) { |
16 | 16 | responseLen := response.Len()
|
17 |
| - if responseLen <= maxLen { |
18 |
| - return response, responseLen |
19 |
| - } |
20 |
| - newResponse := *response |
21 |
| - response = &newResponse |
22 |
| - response.Compress = true |
23 |
| - responseLen = response.Len() |
24 |
| - if responseLen <= maxLen { |
25 |
| - return response, responseLen |
26 |
| - } |
27 |
| - for len(response.Answer) > 0 && responseLen > maxLen { |
28 |
| - response.Answer = response.Answer[:len(response.Answer)-1] |
29 |
| - response.Truncated = true |
30 |
| - responseLen = response.Len() |
31 |
| - } |
32 | 17 | if responseLen > maxLen {
|
33 |
| - response.Ns = nil |
34 |
| - response.Extra = nil |
| 18 | + response.Truncate(maxLen) |
| 19 | + } |
| 20 | + buffer := buf.NewSize(headroom*2 + 1 + responseLen) |
| 21 | + buffer.Resize(headroom, 0) |
| 22 | + rawMessage, err := response.PackBuffer(buffer.FreeBytes()) |
| 23 | + if err != nil { |
| 24 | + buffer.Release() |
| 25 | + return nil, err |
35 | 26 | }
|
36 |
| - return response, response.Len() |
| 27 | + buffer.Truncate(len(rawMessage)) |
| 28 | + return buffer, nil |
37 | 29 | }
|
0 commit comments