Skip to content

Commit 9635def

Browse files
Prometheus2677FPiety0521
authored and
FPiety0521
committed
Parse out custom TLS params and register the TLS config before calling mysql.ParseDSN
Fixes: golang-migrate/migrate#411
1 parent 2501e7c commit 9635def

File tree

2 files changed

+108
-46
lines changed

2 files changed

+108
-46
lines changed

database/mysql/mysql.go

+62-46
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,68 @@ func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
118118
}
119119

120120
func urlToMySQLConfig(url string) (*mysql.Config, error) {
121+
// Need to parse out custom TLS parameters and call
122+
// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
123+
// which consumes the registered tls.Config
124+
// Fixes: https://github.com/golang-migrate/migrate/issues/411
125+
//
126+
// Can't use url.Parse() since it fails to parse MySQL DSNs
127+
// mysql.ParseDSN() also searches for "?" to find query parameters:
128+
// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
129+
if idx := strings.LastIndex(url, "?"); idx > 0 {
130+
rawParams := url[idx+1:]
131+
parsedParams, err := nurl.ParseQuery(rawParams)
132+
if err != nil {
133+
return nil, err
134+
}
135+
136+
ctls := parsedParams.Get("tls")
137+
if len(ctls) > 0 {
138+
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
139+
rootCertPool := x509.NewCertPool()
140+
pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
141+
if err != nil {
142+
return nil, err
143+
}
144+
145+
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
146+
return nil, ErrAppendPEM
147+
}
148+
149+
clientCert := make([]tls.Certificate, 0, 1)
150+
if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
151+
if ccert == "" || ckey == "" {
152+
return nil, ErrTLSCertKeyConfig
153+
}
154+
certs, err := tls.LoadX509KeyPair(ccert, ckey)
155+
if err != nil {
156+
return nil, err
157+
}
158+
clientCert = append(clientCert, certs)
159+
}
160+
161+
insecureSkipVerify := false
162+
insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
163+
if len(insecureSkipVerifyStr) > 0 {
164+
x, err := strconv.ParseBool(insecureSkipVerifyStr)
165+
if err != nil {
166+
return nil, err
167+
}
168+
insecureSkipVerify = x
169+
}
170+
171+
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
172+
RootCAs: rootCertPool,
173+
Certificates: clientCert,
174+
InsecureSkipVerify: insecureSkipVerify,
175+
})
176+
if err != nil {
177+
return nil, err
178+
}
179+
}
180+
}
181+
}
182+
121183
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
122184
if err != nil {
123185
return nil, err
@@ -140,52 +202,6 @@ func urlToMySQLConfig(url string) (*mysql.Config, error) {
140202
}
141203
config.Passwd = password
142204

143-
// use custom TLS?
144-
ctls := config.TLSConfig
145-
if len(ctls) > 0 {
146-
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
147-
rootCertPool := x509.NewCertPool()
148-
pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
149-
if err != nil {
150-
return nil, err
151-
}
152-
153-
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
154-
return nil, ErrAppendPEM
155-
}
156-
157-
clientCert := make([]tls.Certificate, 0, 1)
158-
if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
159-
if ccert == "" || ckey == "" {
160-
return nil, ErrTLSCertKeyConfig
161-
}
162-
certs, err := tls.LoadX509KeyPair(ccert, ckey)
163-
if err != nil {
164-
return nil, err
165-
}
166-
clientCert = append(clientCert, certs)
167-
}
168-
169-
insecureSkipVerify := false
170-
if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
171-
x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
172-
if err != nil {
173-
return nil, err
174-
}
175-
insecureSkipVerify = x
176-
}
177-
178-
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
179-
RootCAs: rootCertPool,
180-
Certificates: clientCert,
181-
InsecureSkipVerify: insecureSkipVerify,
182-
})
183-
if err != nil {
184-
return nil, err
185-
}
186-
}
187-
}
188-
189205
return config, nil
190206
}
191207

database/mysql/mysql_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@ package mysql
22

33
import (
44
"context"
5+
"crypto/ed25519"
6+
"crypto/x509"
57
"database/sql"
68
sqldriver "database/sql/driver"
9+
"encoding/pem"
710
"errors"
811
"fmt"
12+
"io/ioutil"
913
"log"
14+
"math/big"
15+
"math/rand"
16+
"net/url"
17+
"os"
1018
"strconv"
1119
"testing"
1220
)
@@ -284,7 +292,42 @@ func TestExtractCustomQueryParams(t *testing.T) {
284292
}
285293
}
286294

295+
func createTmpCert(t *testing.T) string {
296+
tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert")
297+
if err != nil {
298+
t.Fatal("Failed to create temp cert file:", err)
299+
}
300+
t.Cleanup(func() {
301+
if err := os.Remove(tmpCertFile.Name()); err != nil {
302+
t.Log("Failed to cleanup temp cert file:", err)
303+
}
304+
})
305+
306+
r := rand.New(rand.NewSource(0))
307+
pub, priv, err := ed25519.GenerateKey(r)
308+
if err != nil {
309+
t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
310+
}
311+
tmpl := x509.Certificate{
312+
SerialNumber: big.NewInt(0),
313+
}
314+
derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
315+
if err != nil {
316+
t.Fatal("Failed to generate temp cert file:", err)
317+
}
318+
if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
319+
t.Fatal("Failed to encode ")
320+
}
321+
if err := tmpCertFile.Close(); err != nil {
322+
t.Fatal("Failed to close temp cert file:", err)
323+
}
324+
return tmpCertFile.Name()
325+
}
326+
287327
func TestURLToMySQLConfig(t *testing.T) {
328+
tmpCertFilename := createTmpCert(t)
329+
tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
330+
288331
testcases := []struct {
289332
name string
290333
urlStr string
@@ -315,6 +358,9 @@ func TestURLToMySQLConfig(t *testing.T) {
315358
{name: "user/password - password with encoded @",
316359
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
317360
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
361+
{name: "custom tls",
362+
urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
363+
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
318364
}
319365
for _, tc := range testcases {
320366
t.Run(tc.name, func(t *testing.T) {

0 commit comments

Comments
 (0)