Skip to content

Commit 3b362f5

Browse files
author
Julien Pivotto
authored
Merge pull request #291 from pracucci/add-custom-dialer-option-to-http-client
Added functional option to allow to customize DialContext() in HTTP client
2 parents 4240322 + 6a9c79c commit 3b362f5

File tree

2 files changed

+99
-19
lines changed

2 files changed

+99
-19
lines changed

config/http_config.go

+70-11
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ package config
1717

1818
import (
1919
"bytes"
20+
"context"
2021
"crypto/sha256"
2122
"crypto/tls"
2223
"crypto/x509"
2324
"fmt"
2425
"io/ioutil"
26+
"net"
2527
"net/http"
2628
"net/url"
2729
"strings"
@@ -38,6 +40,12 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
3840
FollowRedirects: true,
3941
}
4042

43+
// defaultHTTPClientOptions holds the default HTTP client options.
44+
var defaultHTTPClientOptions = httpClientOptions{
45+
keepAlivesEnabled: true,
46+
http2Enabled: true,
47+
}
48+
4149
type closeIdler interface {
4250
CloseIdleConnections()
4351
}
@@ -194,15 +202,50 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
194202
return unmarshal((*plain)(a))
195203
}
196204

205+
// DialContextFunc defines the signature of the DialContext() function implemented
206+
// by net.Dialer.
207+
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
208+
209+
type httpClientOptions struct {
210+
dialContextFunc DialContextFunc
211+
keepAlivesEnabled bool
212+
http2Enabled bool
213+
}
214+
215+
// HTTPClientOption defines an option that can be applied to the HTTP client.
216+
type HTTPClientOption func(options *httpClientOptions)
217+
218+
// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
219+
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
220+
return func(opts *httpClientOptions) {
221+
opts.dialContextFunc = fn
222+
}
223+
}
224+
225+
// WithKeepAlivesDisabled allows to disable HTTP keepalive.
226+
func WithKeepAlivesDisabled() HTTPClientOption {
227+
return func(opts *httpClientOptions) {
228+
opts.keepAlivesEnabled = false
229+
}
230+
}
231+
232+
// WithHTTP2Disabled allows to disable HTTP2.
233+
func WithHTTP2Disabled() HTTPClientOption {
234+
return func(opts *httpClientOptions) {
235+
opts.http2Enabled = false
236+
}
237+
}
238+
197239
// NewClient returns a http.Client using the specified http.RoundTripper.
198240
func newClient(rt http.RoundTripper) *http.Client {
199241
return &http.Client{Transport: rt}
200242
}
201243

202244
// NewClientFromConfig returns a new HTTP client configured for the
203-
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
204-
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) {
205-
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2)
245+
// given config.HTTPClientConfig and config.HTTPClientOption.
246+
// The name is used as go-conntrack metric label.
247+
func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (*http.Client, error) {
248+
rt, err := NewRoundTripperFromConfig(cfg, name, optFuncs...)
206249
if err != nil {
207250
return nil, err
208251
}
@@ -216,29 +259,45 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e
216259
}
217260

218261
// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
219-
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
220-
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) {
262+
// given config.HTTPClientConfig and config.HTTPClientOption.
263+
// The name is used as go-conntrack metric label.
264+
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
265+
opts := defaultHTTPClientOptions
266+
for _, f := range optFuncs {
267+
f(&opts)
268+
}
269+
270+
var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
271+
272+
if opts.dialContextFunc != nil {
273+
dialContext = conntrack.NewDialContextFunc(
274+
conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
275+
conntrack.DialWithTracing(),
276+
conntrack.DialWithName(name))
277+
} else {
278+
dialContext = conntrack.NewDialContextFunc(
279+
conntrack.DialWithTracing(),
280+
conntrack.DialWithName(name))
281+
}
282+
221283
newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
222284
// The only timeout we care about is the configured scrape timeout.
223285
// It is applied on request. So we leave out any timings here.
224286
var rt http.RoundTripper = &http.Transport{
225287
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
226288
MaxIdleConns: 20000,
227289
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
228-
DisableKeepAlives: disableKeepAlives,
290+
DisableKeepAlives: !opts.keepAlivesEnabled,
229291
TLSClientConfig: tlsConfig,
230292
DisableCompression: true,
231293
// 5 minutes is typically above the maximum sane scrape interval. So we can
232294
// use keepalive for all configurations.
233295
IdleConnTimeout: 5 * time.Minute,
234296
TLSHandshakeTimeout: 10 * time.Second,
235297
ExpectContinueTimeout: 1 * time.Second,
236-
DialContext: conntrack.NewDialContextFunc(
237-
conntrack.DialWithTracing(),
238-
conntrack.DialWithName(name),
239-
),
298+
DialContext: dialContext,
240299
}
241-
if enableHTTP2 {
300+
if opts.http2Enabled {
242301
// HTTP/2 support is golang has many problematic cornercases where
243302
// dead connections would be kept and used in connection pools.
244303
// https://github.com/golang/go/issues/32388

config/http_config_test.go

+29-8
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
package config
1717

1818
import (
19+
"context"
1920
"crypto/tls"
2021
"crypto/x509"
22+
"errors"
2123
"fmt"
2224
"io/ioutil"
25+
"net"
2326
"net/http"
2427
"net/http/httptest"
2528
"os"
@@ -50,6 +53,7 @@ const (
5053
MissingKey = "missing/secret.key"
5154

5255
ExpectedMessage = "I'm here to serve you!!!"
56+
ExpectedError = "expected error"
5357
AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
5458
AuthorizationCredentialsFile = "testdata/bearer.token"
5559
AuthorizationType = "APIKEY"
@@ -350,7 +354,7 @@ func TestNewClientFromConfig(t *testing.T) {
350354
if err != nil {
351355
t.Fatal(err.Error())
352356
}
353-
client, err := NewClientFromConfig(validConfig.clientConfig, "test", false, true)
357+
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
354358
if err != nil {
355359
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
356360
continue
@@ -400,7 +404,7 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
400404
}
401405

402406
for _, invalidConfig := range newClientInvalidConfig {
403-
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test", false, true)
407+
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test")
404408
if client != nil {
405409
t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig)
406410
}
@@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
413417
}
414418
}
415419

420+
func TestCustomDialContextFunc(t *testing.T) {
421+
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
422+
return nil, errors.New(ExpectedError)
423+
}
424+
425+
cfg := HTTPClientConfig{}
426+
client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn))
427+
if err != nil {
428+
t.Fatalf("Can't create a client from this config: %+v", cfg)
429+
}
430+
431+
_, err = client.Get("http://localhost")
432+
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
433+
t.Errorf("Expected error %q but got %q", ExpectedError, err)
434+
}
435+
}
436+
416437
func TestMissingBearerAuthFile(t *testing.T) {
417438
cfg := HTTPClientConfig{
418439
BearerTokenFile: MissingBearerTokenFile,
@@ -439,7 +460,7 @@ func TestMissingBearerAuthFile(t *testing.T) {
439460
}
440461
defer testServer.Close()
441462

442-
client, err := NewClientFromConfig(cfg, "test", false, true)
463+
client, err := NewClientFromConfig(cfg, "test")
443464
if err != nil {
444465
t.Fatal(err)
445466
}
@@ -637,7 +658,7 @@ func TestBasicAuthNoPassword(t *testing.T) {
637658
if err != nil {
638659
t.Fatalf("Error loading HTTP client config: %v", err)
639660
}
640-
client, err := NewClientFromConfig(*cfg, "test", false, true)
661+
client, err := NewClientFromConfig(*cfg, "test")
641662
if err != nil {
642663
t.Fatalf("Error creating HTTP Client: %v", err)
643664
}
@@ -663,7 +684,7 @@ func TestBasicAuthNoUsername(t *testing.T) {
663684
if err != nil {
664685
t.Fatalf("Error loading HTTP client config: %v", err)
665686
}
666-
client, err := NewClientFromConfig(*cfg, "test", false, true)
687+
client, err := NewClientFromConfig(*cfg, "test")
667688
if err != nil {
668689
t.Fatalf("Error creating HTTP Client: %v", err)
669690
}
@@ -689,7 +710,7 @@ func TestBasicAuthPasswordFile(t *testing.T) {
689710
if err != nil {
690711
t.Fatalf("Error loading HTTP client config: %v", err)
691712
}
692-
client, err := NewClientFromConfig(*cfg, "test", false, true)
713+
client, err := NewClientFromConfig(*cfg, "test")
693714
if err != nil {
694715
t.Fatalf("Error creating HTTP Client: %v", err)
695716
}
@@ -840,7 +861,7 @@ func TestTLSRoundTripper(t *testing.T) {
840861
writeCertificate(bs, tc.cert, cert)
841862
writeCertificate(bs, tc.key, key)
842863
if c == nil {
843-
c, err = NewClientFromConfig(cfg, "test", false, true)
864+
c, err = NewClientFromConfig(cfg, "test")
844865
if err != nil {
845866
t.Fatalf("Error creating HTTP Client: %v", err)
846867
}
@@ -912,7 +933,7 @@ func TestTLSRoundTripperRaces(t *testing.T) {
912933
writeCertificate(bs, TLSCAChainPath, ca)
913934
writeCertificate(bs, ClientCertificatePath, cert)
914935
writeCertificate(bs, ClientKeyNoPassPath, key)
915-
c, err = NewClientFromConfig(cfg, "test", false, true)
936+
c, err = NewClientFromConfig(cfg, "test")
916937
if err != nil {
917938
t.Fatalf("Error creating HTTP Client: %v", err)
918939
}

0 commit comments

Comments
 (0)