Skip to content

Commit 33b7747

Browse files
Add BeforeConnect callback to configuration object (#1469)
This can be used to alter the connection options for each connection, right before it's established Co-authored-by: Inada Naoki <[email protected]>
1 parent 6964272 commit 33b7747

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ GitHub Inc.
132132
Google Inc.
133133
InfoSum Ltd.
134134
Keybase Inc.
135+
Microsoft Corp.
135136
Multiplay Ltd.
136137
Percona LLC
137138
PingCAP Inc.

connector.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,22 @@ func newConnector(cfg *Config) *connector {
6666
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6767
var err error
6868

69+
// Invoke beforeConnect if present, with a copy of the configuration
70+
cfg := c.cfg
71+
if c.cfg.beforeConnect != nil {
72+
cfg = c.cfg.Clone()
73+
err = c.cfg.beforeConnect(ctx, cfg)
74+
if err != nil {
75+
return nil, err
76+
}
77+
}
78+
6979
// New mysqlConn
7080
mc := &mysqlConn{
7181
maxAllowedPacket: maxPacketSize,
7282
maxWriteSize: maxPacketSize - 1,
7383
closech: make(chan struct{}),
74-
cfg: c.cfg,
84+
cfg: cfg,
7585
connector: c,
7686
}
7787
mc.parseTime = mc.cfg.ParseTime

driver_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,40 @@ func TestCustomDial(t *testing.T) {
20442044
}
20452045
}
20462046

2047+
func TestBeforeConnect(t *testing.T) {
2048+
if !available {
2049+
t.Skipf("MySQL server not running on %s", netAddr)
2050+
}
2051+
2052+
// dbname is set in the BeforeConnect handle
2053+
cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_"))
2054+
if err != nil {
2055+
t.Fatalf("error parsing DSN: %v", err)
2056+
}
2057+
2058+
cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error {
2059+
c.DBName = dbname
2060+
return nil
2061+
}))
2062+
2063+
connector, err := NewConnector(cfg)
2064+
if err != nil {
2065+
t.Fatalf("error creating connector: %v", err)
2066+
}
2067+
2068+
db := sql.OpenDB(connector)
2069+
defer db.Close()
2070+
2071+
var connectedDb string
2072+
err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb)
2073+
if err != nil {
2074+
t.Fatalf("error executing query: %v", err)
2075+
}
2076+
if connectedDb != dbname {
2077+
t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb)
2078+
}
2079+
}
2080+
20472081
func TestSQLInjection(t *testing.T) {
20482082
createTest := func(arg string) func(dbt *DBTest) {
20492083
return func(dbt *DBTest) {

dsn.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"context"
1314
"crypto/rsa"
1415
"crypto/tls"
1516
"errors"
@@ -71,8 +72,9 @@ type Config struct {
7172

7273
// unexported fields. new options should be come here
7374

74-
pubKey *rsa.PublicKey // Server public key
75-
timeTruncate time.Duration // Truncate time.Time values to the specified duration
75+
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
76+
pubKey *rsa.PublicKey // Server public key
77+
timeTruncate time.Duration // Truncate time.Time values to the specified duration
7678
}
7779

7880
// Functional Options Pattern
@@ -112,6 +114,14 @@ func TimeTruncate(d time.Duration) Option {
112114
}
113115
}
114116

117+
// BeforeConnect sets the function to be invoked before a connection is established.
118+
func BeforeConnect(fn func(context.Context, *Config) error) Option {
119+
return func(cfg *Config) error {
120+
cfg.beforeConnect = fn
121+
return nil
122+
}
123+
}
124+
115125
func (cfg *Config) Clone() *Config {
116126
cp := *cfg
117127
if cp.TLS != nil {

0 commit comments

Comments
 (0)