Skip to content

Commit f2ae871

Browse files
committed
make tests more reliable
1 parent 68ba446 commit f2ae871

File tree

1 file changed

+55
-61
lines changed

1 file changed

+55
-61
lines changed

contrib/miekg/dns/dns_test.go

+55-61
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package dns_test
88
import (
99
"context"
1010
"net"
11+
"sync"
1112
"testing"
1213
"time"
1314

@@ -28,136 +29,134 @@ func (th *testHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
2829
w.WriteMsg(m)
2930
}
3031

31-
func startServer(t *testing.T, traced bool) (*dns.Server, func()) {
32+
func startServer(t *testing.T, traced bool) (*dns.Server, string) {
3233
var h dns.Handler = &testHandler{}
3334
if traced {
3435
h = dnstrace.WrapHandler(h)
3536
}
36-
addr := getAddr(t).String()
37-
server := &dns.Server{
38-
Addr: addr,
39-
Net: "udp",
40-
Handler: h,
37+
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
38+
require.NoError(t, err)
39+
40+
srv := &dns.Server{
41+
PacketConn: pc,
42+
ReadTimeout: time.Hour,
43+
WriteTimeout: time.Hour,
44+
Handler: h,
4145
}
4246

43-
// start the server
47+
waitLock := sync.Mutex{}
48+
waitLock.Lock()
49+
srv.NotifyStartedFunc = waitLock.Unlock
50+
4451
go func() {
45-
err := server.ListenAndServe()
46-
if err != nil {
47-
t.Error(err)
48-
}
52+
require.NoError(t, srv.ActivateAndServe())
4953
}()
50-
waitTillUDPReady(addr)
51-
stopServer := func() {
52-
err := server.Shutdown()
53-
assert.NoError(t, err)
54-
}
55-
return server, stopServer
54+
t.Cleanup(func() {
55+
require.NoError(t, srv.Shutdown())
56+
})
57+
58+
waitLock.Lock()
59+
return srv, pc.LocalAddr().String()
5660
}
5761

5862
func TestExchange(t *testing.T) {
59-
server, stopServer := startServer(t, false)
60-
defer stopServer()
63+
_, addr := startServer(t, false)
6164

6265
mt := mocktracer.Start()
6366
defer mt.Stop()
6467

6568
m := newMessage()
6669

67-
_, err := dnstrace.Exchange(m, server.Addr)
68-
assert.NoError(t, err)
70+
_, err := dnstrace.Exchange(m, addr)
71+
require.NoError(t, err)
6972

7073
spans := mt.FinishedSpans()
7174
require.Len(t, spans, 1)
7275
assertClientSpan(t, spans[0])
7376
}
7477

7578
func TestExchangeContext(t *testing.T) {
76-
server, stopServer := startServer(t, false)
77-
defer stopServer()
79+
_, addr := startServer(t, false)
7880

7981
mt := mocktracer.Start()
8082
defer mt.Stop()
8183

8284
m := newMessage()
8385

84-
_, err := dnstrace.ExchangeContext(context.Background(), m, server.Addr)
85-
assert.NoError(t, err)
86+
_, err := dnstrace.ExchangeContext(context.Background(), m, addr)
87+
require.NoError(t, err)
8688

8789
spans := mt.FinishedSpans()
8890
require.Len(t, spans, 1)
8991
assertClientSpan(t, spans[0])
9092
}
9193

9294
func TestExchangeConn(t *testing.T) {
93-
server, stopServer := startServer(t, false)
94-
defer stopServer()
95+
_, addr := startServer(t, false)
9596

9697
mt := mocktracer.Start()
9798
defer mt.Stop()
9899

99100
m := newMessage()
100101

101-
conn, err := net.Dial("udp", server.Addr)
102+
conn, err := net.Dial("udp", addr)
102103
require.NoError(t, err)
103104

104105
_, err = dnstrace.ExchangeConn(conn, m)
105-
assert.NoError(t, err)
106+
require.NoError(t, err)
106107

107108
spans := mt.FinishedSpans()
108109
require.Len(t, spans, 1)
109110
assertClientSpan(t, spans[0])
110111
}
111112

112113
func TestClient_Exchange(t *testing.T) {
113-
server, stopServer := startServer(t, false)
114-
defer stopServer()
114+
_, addr := startServer(t, false)
115115

116116
mt := mocktracer.Start()
117117
defer mt.Stop()
118118

119119
m := newMessage()
120-
121120
client := newTracedClient()
122121

123-
_, _, err := client.Exchange(m, server.Addr)
124-
assert.NoError(t, err)
122+
_, _, err := client.Exchange(m, addr)
123+
require.NoError(t, err)
125124

126125
spans := mt.FinishedSpans()
127126
require.Len(t, spans, 1)
128127
assertClientSpan(t, spans[0])
129128
}
130129

131130
func TestClient_ExchangeContext(t *testing.T) {
132-
server, stopServer := startServer(t, false)
133-
defer stopServer()
131+
_, addr := startServer(t, false)
134132

135133
mt := mocktracer.Start()
136134
defer mt.Stop()
137135

138136
m := newMessage()
139-
140137
client := newTracedClient()
141138

142-
_, _, err := client.ExchangeContext(context.Background(), m, server.Addr)
143-
assert.NoError(t, err)
139+
_, _, err := client.ExchangeContext(context.Background(), m, addr)
140+
require.NoError(t, err)
144141

145142
spans := mt.FinishedSpans()
146143
require.Len(t, spans, 1)
147144
assertClientSpan(t, spans[0])
148145
}
149146

150147
func TestWrapHandler(t *testing.T) {
151-
server, stopServer := startServer(t, true)
148+
_, addr := startServer(t, true)
152149

153150
mt := mocktracer.Start()
154151
defer mt.Stop()
155152

156153
m := newMessage()
157-
_, err := dns.Exchange(m, server.Addr)
158-
assert.NoError(t, err)
154+
client := newClient()
155+
156+
_, _, err := client.Exchange(m, addr)
157+
require.NoError(t, err)
159158

160-
stopServer() // Shutdown server so span is closed after DNS request
159+
waitForSpans(mt, 1)
161160

162161
spans := mt.FinishedSpans()
163162
require.Len(t, spans, 1)
@@ -177,8 +176,12 @@ func newMessage() *dns.Msg {
177176
return m
178177
}
179178

179+
func newClient() *dns.Client {
180+
return &dns.Client{Net: "udp"}
181+
}
182+
180183
func newTracedClient() *dnstrace.Client {
181-
return &dnstrace.Client{Client: &dns.Client{Net: "udp"}}
184+
return &dnstrace.Client{Client: newClient()}
182185
}
183186

184187
func assertClientSpan(t *testing.T, s mocktracer.Span) {
@@ -190,24 +193,15 @@ func assertClientSpan(t *testing.T, s mocktracer.Span) {
190193
assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind))
191194
}
192195

193-
func getAddr(t *testing.T) net.Addr {
194-
li, err := net.Listen("tcp4", "127.0.0.1:0")
195-
if err != nil {
196-
t.Fatal(err)
197-
}
198-
addr := li.Addr()
199-
li.Close()
200-
return addr
201-
}
196+
func waitForSpans(mt mocktracer.Tracer, sz int) {
197+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
198+
defer cancel()
202199

203-
func waitTillUDPReady(addr string) {
204-
deadline := time.Now().Add(time.Second * 10)
205-
for time.Now().Before(deadline) {
206-
m := new(dns.Msg)
207-
m.SetQuestion("miek.nl.", dns.TypeMX)
208-
_, err := dns.Exchange(m, addr)
209-
if err == nil {
210-
break
200+
for len(mt.FinishedSpans()) < sz {
201+
select {
202+
case <-ctx.Done():
203+
return
204+
default:
211205
}
212206
time.Sleep(time.Millisecond * 100)
213207
}

0 commit comments

Comments
 (0)