From 1cff4865e44e0063020a94a28e936c46def168d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 9 Jun 2024 07:15:47 +0800 Subject: [PATCH] Improve TLS transport --- transport_tls.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/transport_tls.go b/transport_tls.go index dcd014f..93eeead 100644 --- a/transport_tls.go +++ b/transport_tls.go @@ -93,21 +93,28 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg t.access.Lock() conn := t.connections.PopFront() t.access.Unlock() - if conn == nil { - tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) - if err != nil { - return nil, err + if conn != nil { + response, err := t.exchange(message, conn) + if err == nil { + return response, nil } - tlsConn := tls.Client(tcpConn, &tls.Config{ - ServerName: t.serverAddr.AddrString(), - }) - err = tlsConn.HandshakeContext(ctx) - if err != nil { - tcpConn.Close() - return nil, err - } - conn = &tlsDNSConn{Conn: tlsConn} } + tcpConn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(tcpConn, &tls.Config{ + ServerName: t.serverAddr.AddrString(), + }) + err = tlsConn.HandshakeContext(ctx) + if err != nil { + tcpConn.Close() + return nil, err + } + return t.exchange(message, &tlsDNSConn{Conn: tlsConn}) +} + +func (t *TLSTransport) exchange(message *dns.Msg, conn *tlsDNSConn) (*dns.Msg, error) { messageId := message.Id conn.queryId++ message.Id = conn.queryId