Skip to content

Commit

Permalink
Improve TCP/TLS transports
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 2, 2024
1 parent 7fecf77 commit a5c892e
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 34 deletions.
18 changes: 17 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import (
"github.com/miekg/dns"
)

const DefaultTTL = 600
const (
DefaultTTL = 600
DefaultTimeout = 5 * time.Second
)

var (
ErrNoRawSupport = E.New("no raw query support by current transport")
Expand All @@ -26,6 +29,7 @@ var (
)

type Client struct {
timeout time.Duration
disableCache bool
disableExpire bool
independentCache bool
Expand All @@ -48,6 +52,7 @@ type transportCacheKey struct {
}

type ClientOptions struct {
Timeout time.Duration
DisableCache bool
DisableExpire bool
IndependentCache bool
Expand All @@ -57,12 +62,16 @@ type ClientOptions struct {

func NewClient(options ClientOptions) *Client {
client := &Client{
timeout: options.Timeout,
disableCache: options.DisableCache,
disableExpire: options.DisableExpire,
independentCache: options.IndependentCache,
initRDRCFunc: options.RDRC,
logger: options.Logger,
}
if client.timeout == 0 {
client.timeout = DefaultTimeout
}
if !client.disableCache {
if !client.independentCache {
client.cache = cache.New[dns.Question, *dns.Msg]()
Expand Down Expand Up @@ -148,7 +157,14 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return nil, ErrResponseRejectedCached
}
}
var cancel context.CancelFunc
if c.timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.timeout)
}
response, err := transport.Exchange(ctx, message)
if cancel != nil {
cancel()
}
if err != nil {
return nil, err
}
Expand Down
75 changes: 45 additions & 30 deletions transport_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ type myTransportAdapter struct {
dialer N.Dialer
serverAddr M.Socksaddr
clientAddr netip.Prefix
reuse bool
handler myTransportHandler
access sync.Mutex
conn *dnsConnection
}

func newAdapter(options TransportOptions, serverAddr M.Socksaddr) myTransportAdapter {
func newAdapter(options TransportOptions, serverAddr M.Socksaddr, reuse bool) myTransportAdapter {
ctx, cancel := context.WithCancel(options.Context)
return myTransportAdapter{
name: options.Name,
Expand All @@ -43,6 +44,7 @@ func newAdapter(options TransportOptions, serverAddr M.Socksaddr) myTransportAda
dialer: options.Dialer,
serverAddr: serverAddr,
clientAddr: options.ClientSubnet,
reuse: reuse,
}
}

Expand All @@ -55,18 +57,21 @@ func (t *myTransportAdapter) Start() error {
}

func (t *myTransportAdapter) open(ctx context.Context) (*dnsConnection, error) {
connection := t.conn
if connection != nil {
if !common.Done(connection.ctx) {
return connection, nil
var connection *dnsConnection
if t.reuse {
connection = t.conn
if connection != nil {
if !common.Done(connection.ctx) {
return connection, nil
}
}
}
t.access.Lock()
defer t.access.Unlock()
connection = t.conn
if connection != nil {
if !common.Done(connection.ctx) {
return connection, nil
t.access.Lock()
defer t.access.Unlock()
connection = t.conn
if connection != nil {
if !common.Done(connection.ctx) {
return connection, nil
}
}
}
conn, err := t.handler.DialContext(ctx)
Expand All @@ -81,7 +86,9 @@ func (t *myTransportAdapter) open(ctx context.Context) (*dnsConnection, error) {
callbacks: make(map[uint16]*dnsCallback),
}
go t.recvLoop(connection)
t.conn = connection
if t.reuse {
t.conn = connection
}
return connection, nil
}

Expand Down Expand Up @@ -110,31 +117,24 @@ func (t *myTransportAdapter) recvLoop(conn *dnsConnection) {
}
})
group.Cleanup(func() {
conn.cancel()
conn.Close()
})
group.Run(conn.ctx)
}

func (t *myTransportAdapter) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) {
messageId := message.Id
var (
conn *dnsConnection
err error
)
for attempts := 0; attempts < 2; attempts++ {
conn, err = t.open(t.ctx)
if err == nil {
break
}
}
conn, err := t.open(t.ctx)
if err != nil {
return nil, err
}
response, err := t.exchange(ctx, conn, message)
if err != nil {
return nil, err
}
if !t.reuse {
conn.Close()
}
response.Id = messageId
return response, nil
}
Expand All @@ -151,11 +151,22 @@ func (t *myTransportAdapter) exchange(ctx context.Context, conn *dnsConnection,
conn.callbacks[exMessage.Id] = callback
conn.access.Unlock()
defer t.cleanup(conn, exMessage.Id, callback)
conn.writeAccess.Lock()
err := t.handler.WriteMessage(conn, &exMessage)
conn.writeAccess.Unlock()
var err error
done := make(chan struct{})
go func() {
conn.writeAccess.Lock()
err = t.handler.WriteMessage(conn, &exMessage)
conn.writeAccess.Unlock()
close(done)
}()
select {
case <-done:
case <-ctx.Done():
conn.Close()
return nil, ctx.Err()
}
if err != nil {
conn.cancel()
conn.Close()
return nil, err
}
select {
Expand All @@ -165,7 +176,7 @@ func (t *myTransportAdapter) exchange(ctx context.Context, conn *dnsConnection,
case <-conn.ctx.Done():
return nil, E.Errors(conn.err, conn.ctx.Err())
case <-ctx.Done():
conn.cancel()
conn.Close()
return nil, ctx.Err()
}
}
Expand All @@ -186,7 +197,6 @@ func (t *myTransportAdapter) cleanup(conn *dnsConnection, messageId uint16, call
func (t *myTransportAdapter) Reset() {
conn := t.conn
if conn != nil {
conn.cancel()
conn.Close()
}
}
Expand Down Expand Up @@ -215,6 +225,11 @@ type dnsConnection struct {
callbacks map[uint16]*dnsCallback
}

func (c *dnsConnection) Close() error {
c.cancel()
return c.Conn.Close()
}

type dnsCallback struct {
access sync.Mutex
message *dns.Msg
Expand Down
2 changes: 1 addition & 1 deletion transport_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewTCPTransport(options TransportOptions) (*TCPTransport, error) {

func newTCPTransport(options TransportOptions, serverAddr M.Socksaddr) *TCPTransport {
transport := &TCPTransport{
newAdapter(options, serverAddr),
newAdapter(options, serverAddr, false),
}
transport.handler = transport
return transport
Expand Down
2 changes: 1 addition & 1 deletion transport_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewTLSTransport(options TransportOptions) (*TLSTransport, error) {
serverAddr.Port = 853
}
transport := &TLSTransport{
newAdapter(options, serverAddr),
newAdapter(options, serverAddr, true),
}
transport.handler = transport
return transport, nil
Expand Down
2 changes: 1 addition & 1 deletion transport_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewUDPTransport(options TransportOptions) (*UDPTransport, error) {
serverAddr.Port = 53
}
transport := &UDPTransport{
newAdapter(options, serverAddr),
newAdapter(options, serverAddr, true),
newTCPTransport(options, serverAddr),
options.Logger,
512,
Expand Down

0 comments on commit a5c892e

Please sign in to comment.