diff --git a/examples/tls/foo.pem b/examples/tls/foo.pem new file mode 100644 index 0000000..2b4be00 --- /dev/null +++ b/examples/tls/foo.pem @@ -0,0 +1,60 @@ +Bag Attributes + friendlyName: example + localKeyID: 54 69 6D 65 20 31 36 33 36 39 39 39 38 33 32 35 34 36 +Key Attributes: +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFDjBABgkqhkiG9w0BBQ0wMzAbBgkqhkiG9w0BBQwwDgQIUxQvRQEi5IwCAggA +MBQGCCqGSIb3DQMHBAjHeWCQiSxHuASCBMjXArMh0EaguG7HW1I+ZsTAD5EJGU2C +QbClEedcPGYiDoQbHcz06rrFQPwds03QkS/M15I1QGoVxFjIgKVQsLd9AI01vwZd +15wbDJhKei7Q+YCmkxrjhyndBk4mwyls1Hfas4BafdLUgPanBM3Sb9FZEM63odNx +/RRnwrm1jlMTd8bgZgg4UBOK/vgT8CVDCQ98SB5jJEvXyPP8OMM4FFST/Ea2H+Ze +f+QBPo2IvQgvVwmUdGlAE/UWMe7E3cq1m5EnA4qUIAzRNXdnuOrZuP355LENqAoX +gCkneQKYUDTAGNYe2N0BkmvPGeBMSGgxP58jgHyCHhaCAVfioy6f7+4qF03WqxVe +VWhfgCdvo9AnOIn4iOG6wVPeQv55oZMosEPLe+E+1uUJDe0ppaBJRNO3hFlk4j/O +Hy84mNN78Sq8s+m2AMQvrlaLtJmcag0jfrCYv9/EN7MVMeTRjA1/DCGRf+Yv68DL +WJZCcj9NBMLY0ZG+fI5biII8OTOTPHn6l5dQ+fNzWYGcDSHpdjgrNCysxVZz56WZ +cq++jYmq6OyFxRWQE/Og25VvolKjkGCCFQp+53PiXs73dT2EppQzoUjEepw3Lv5c +nQwHTUkd7EMP1DzFTlU2QbeXkuLfEzZnT3tZjbbRBsJ4yyTtxLc2HWQUpPeV164f +Y731R7mLryIWv+FxwIjR3LSP75C3zxN/ypequBYV1xJlzjd0oJBgrXpM5yg6Y5YH +it9QHNYs1tmU/28w+4Nv3zQC6sY2ZK2+yWiZUSKFRoOPKaAj6EmH/E4fCa5HgmVS +XM5Kk+MzoRYU6AIG3rkvQXNYHuuaeUaQY2MtEtCZPvYb2N3usS3wyhUQ8eD1tFEt +jmrge8h7zAN19hyTggQflx2rbLBJI7Caghl/h3Lcj6aGZZq1/lesY4SuJ3yEhy/Q +cNKIRQ+h+40LABZ+PO+Zww2oSVS3UNU5OSZuBP5jnNhoYIGi7CONepIJNzv11mmv +D/t6SFheddG572fPU27KiyQ/Ooiavyv8qTBhSqtCTxoi5zQ+dQpTrtomL2g/vFJT +1ah84GFIo1PAdRT9lxFaznmPWFPh3HSpNo7BsrbHXFD/cRtF+DBUb/olMHf4DFQ4 +XdwcK/+xAOp0eYWwo6RSt+CBP9905m3byrd7cbb71zdbUWxpSn7gb9nwtFn+EGEn +ct1YaJ5W7xiGcb1TpJT4b/Y4PKXeddCUNq3SyJY7n6wvmxUtwYd/LeX38/WS0+FI +pnMMYjrkKmyAoj3cZpz0pXUk26bpLCWxPhUmZDuSPrfEQbgW3RJ6iHUmy44EFS+e +Mn72/i7CuEJuq0c5C8uBkxjOfyMNtVVhr1NYSw0ebtRk5WivzvMDH8Fy5hN3EmE1 +6GTbETJGnL8w43EpoC5OnWcli993RpxZVNU12I808+9JgLQqq+uc6jCxLSJJRIAc +TddSrCArKXTDT0Bbom2gII0gvkkmkxVGC1B2OTAVkIU29IHBCh4MZdXmPepnAaZu +qJAyDdVSgg/k0ao1VZTGzC4dHcnUvl9h6eYiGTCy/N2ATCbLmSdltIr01IbptDZ8 +KQSqVpMA0R/5qATg0x33B3EpReKWFEIuF9eK/b/TEJsEr9Mxq+rs3I/TD9kRMd8k +YkQ= +-----END ENCRYPTED PRIVATE KEY----- +Bag Attributes + friendlyName: example + localKeyID: 54 69 6D 65 20 31 36 33 36 39 39 39 38 33 32 35 34 36 +subject=/C=Unknown/ST=Unknown/L=Unknown/O=Unknown/OU=Unknown/CN=Unknown +issuer=/C=Unknown/ST=Unknown/L=Unknown/O=Unknown/OU=Unknown/CN=Unknown +-----BEGIN CERTIFICATE----- +MIIDdzCCAl+gAwIBAgIEY6rrgzANBgkqhkiG9w0BAQsFADBsMRAwDgYDVQQGEwdV +bmtub3duMRAwDgYDVQQIEwdVbmtub3duMRAwDgYDVQQHEwdVbmtub3duMRAwDgYD +VQQKEwdVbmtub3duMRAwDgYDVQQLEwdVbmtub3duMRAwDgYDVQQDEwdVbmtub3du +MB4XDTIxMTExNTE4MDYwM1oXDTIyMTExNTE4MDYwM1owbDEQMA4GA1UEBhMHVW5r +bm93bjEQMA4GA1UECBMHVW5rbm93bjEQMA4GA1UEBxMHVW5rbm93bjEQMA4GA1UE +ChMHVW5rbm93bjEQMA4GA1UECxMHVW5rbm93bjEQMA4GA1UEAxMHVW5rbm93bjCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMBaovqLoe5Jg5UT5ckBSuL9 +zZthmNUP76QTGSZhdVJU/bnf2DCJDwx0xOmZtdQ/pbupv5KcXmEh27OPNyIrN/G/ +lpiUiquMjSf44KCnKZwBM/QPiVTZvEWAlKyJu4nzTNtagLj+kIybPQzPIimw3g9f +RYXkHAANe82xodPuTUsIV5+T15IXBYyGSCM+nEtpA3KE9Tjt4jji1jpsprbzfnuD +fIxnMlRoFklT3AoQclM2R+IhwDxhGmh/8uH4EDoHTjwMvo73RJh+ffO5LrJL1Jwa +y8zQLS5C1omPwQmEwBpetBXME90S/3o68zrFU8ffKiJyReZxtWtcVH+wkTInMF8C +AwEAAaMhMB8wHQYDVR0OBBYEFEIkA/Ea9liBA6wZu7brPGh6RON9MA0GCSqGSIb3 +DQEBCwUAA4IBAQCsMxXgiP4cSvxjzTeW/Kp84NS3sXGX0Ia/bMtPDZYkzkIMoDoW +z8xrsHqWB9E4VwvmsAtTWV2ICwcSOiSym41CevOpFAHeRjQDZkwqQ9nOL2HmKji5 +wlCnDxB+Yetr4aPCrYwYjLoo2Ge77uOUneu3LJMJKYwRFB8s5U1r5QlaIVusDeOy +AIALlr0KFfqxVNXZ8WRvHBTHYnNECto+DXIqpkIC8s2wyxiWKVS2RUYdFxQLo3NP +SQNRU5uiv6dc7EQA67IY8SQtKB+wiMyxxIVMnBsyy3d3iepQcHmR3FH7SIWr79cn +HPQF833pXriJ8Lw0wdU8IT+p9eSWkaY2Ce2J +-----END CERTIFICATE----- diff --git a/examples/tls/main_test.go b/examples/tls/main_test.go new file mode 100644 index 0000000..f788da6 --- /dev/null +++ b/examples/tls/main_test.go @@ -0,0 +1,35 @@ +package main + +import ( + "testing" + + "database/sql/driver" + + "github.com/kr/pretty" + "github.com/stretchr/testify/assert" + + "github.com/VoltDB/voltdb-client-go/voltdbclient" +) + +func TestMain(t *testing.T) { + + conn, err := voltdbclient.OpenTLSConn("127.0.0.1", voltdbclient.ClientConfig{"foo.pem", false}) + assert.NotNil(t, err) + assert.Nil(t, conn) + + conn, err = voltdbclient.OpenTLSConn("127.0.0.1", voltdbclient.ClientConfig{"foo.pem", true}) + assert.Nil(t, err) + assert.NotNil(t, conn) + + var params []driver.Value + + for _, s := range []interface{}{"PAUSE_CHECK", int32(0)} { + params = append(params, s) + } + + vr, err := conn.Query("@Statistics", params) + assert.Nil(t, err) + assert.NotNil(t, vr) + + pretty.Print(vr) +} diff --git a/go.mod b/go.mod index 5134f2f..743c3c1 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module github.com/VoltDB/voltdb-client-go go 1.16 -require github.com/spaolacci/murmur3 v1.1.0 +require ( + github.com/kr/pretty v0.3.0 // indirect + github.com/spaolacci/murmur3 v1.1.0 + github.com/stretchr/testify v1.7.0 // indirect +) diff --git a/go.sum b/go.sum index c14ec85..6156e5c 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,25 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/voltdbclient/distributor.go b/voltdbclient/distributor.go index 9e2341c..3e8d51e 100644 --- a/voltdbclient/distributor.go +++ b/voltdbclient/distributor.go @@ -24,6 +24,7 @@ import ( "log" "math/rand" "strings" + "io/ioutil" "sync/atomic" "time" ) @@ -45,6 +46,7 @@ var ProtocolVersion = 1 // Conn holds the set of currently active connections. type Conn struct { + pemPath string closeCh chan chan bool open atomic.Value rl rateLimiter @@ -64,6 +66,24 @@ type Conn struct { partitionMasters map[int]*nodeConn } +func newTLSConn(cis []string, clientConfig ClientConfig) (*Conn, error) { + var c = &Conn{ + pemPath: clientConfig.PEMPath, + closeCh: make(chan chan bool), + rl: newTxnLimiter(), + drainCh: make(chan chan bool), + useClientAffinity: true, + partitionMasters: make(map[int]*nodeConn), + } + c.open.Store(true) + + if err := c.start(cis, clientConfig.InsecureSkipVerify); err != nil { + return nil, err + } + + return c, nil +} + func newConn(cis []string) (*Conn, error) { var c = &Conn{ closeCh: make(chan chan bool), @@ -74,7 +94,7 @@ func newConn(cis []string) (*Conn, error) { } c.open.Store(true) - if err := c.start(cis); err != nil { + if err := c.start(cis, false); err != nil { return nil, err } @@ -125,6 +145,21 @@ func OpenConn(ci string) (*Conn, error) { return newConn(cis) } +// OpenTLSConn uses TLS for network connections +func OpenTLSConn(ci string, clientConfig ClientConfig) (*Conn, error) { + ci = strings.TrimSpace(ci) + if ci == "" { + return nil, ErrMissingServerArgument + } + cis := strings.Split(ci, ",") + return newTLSConn(cis, clientConfig) +} + +type ClientConfig struct { + PEMPath string + InsecureSkipVerify bool +} + // OpenConnWithLatencyTarget returns a new connection to the VoltDB server. // This connection will try to meet the specified latency target, potentially by // throttling the rate at which asynchronous transactions are submitted. @@ -160,7 +195,7 @@ func OpenConnWithMaxOutstandingTxns(ci string, maxOutTxns int) (*Conn, error) { return c, nil } -func (c *Conn) start(cis []string) error { +func (c *Conn) start(cis []string, insecureSkipVerify bool) error { var ( err error disconnected []*nodeConn @@ -168,7 +203,16 @@ func (c *Conn) start(cis []string) error { ) for _, ci := range cis { - nc := newNodeConn(ci) + var nc *nodeConn + if len(c.pemPath) > 0 { + pemBytes, err := ioutil.ReadFile(c.pemPath) + if err != nil { + return err + } + nc = newNodeTLSConn(ci, insecureSkipVerify, pemBytes) + } else { + nc = newNodeConn(ci) + } if err = nc.connect(ProtocolVersion); err != nil { disconnected = append(disconnected, nc) continue diff --git a/voltdbclient/node_conn.go b/voltdbclient/node_conn.go index 3159d8c..8ede693 100644 --- a/voltdbclient/node_conn.go +++ b/voltdbclient/node_conn.go @@ -19,7 +19,6 @@ package voltdbclient import ( "bytes" - "database/sql/driver" "errors" "fmt" "io" @@ -28,9 +27,14 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" + "database/sql/driver" + "crypto/x509" + "crypto/tls" + "sync/atomic" + "net/url" + "github.com/VoltDB/voltdb-client-go/wire" ) @@ -45,9 +49,12 @@ const defaultMaxRetries = 10 const defaultRetryInterval = time.Second type nodeConn struct { + pemBytes []byte + insecureSkipVerify bool connInfo string connData *wire.ConnInfo tcpConn *net.TCPConn + tlsConn *tls.Conn drainCh chan chan bool bpCh chan chan bool closeCh chan chan bool @@ -81,6 +88,19 @@ func newNodeConn(ci string) *nodeConn { } } +func newNodeTLSConn(ci string, insecureSkipVerify bool, pemBytes []byte) *nodeConn { + return &nodeConn{ + pemBytes: pemBytes, + insecureSkipVerify: insecureSkipVerify, + connInfo: ci, + bpCh: make(chan chan bool), + closeCh: make(chan chan bool), + drainCh: make(chan chan bool), + responseCh: make(chan *bytes.Buffer, maxResponseBuffer), + requests: &sync.Map{}, + } +} + func (nc *nodeConn) submit(pi *procedureInvocation) (int, error) { if nc.isClosed() { return 0, fmt.Errorf("%s:%d writing on a closed node connection", @@ -91,8 +111,14 @@ func (nc *nodeConn) submit(pi *procedureInvocation) (int, error) { func (nc *nodeConn) markClosed() { nc.closed.Store(true) - nc.tcpConn.Close() + if nc.tcpConn != nil { + nc.tcpConn.Close() + } + + if nc.tlsConn != nil { + nc.tlsConn.Close() + } // release all stored pending requests. This connection is closed so we can't // satisfy the requests. // @@ -115,18 +141,25 @@ func (nc *nodeConn) close() chan bool { } func (nc *nodeConn) connect(protocolVersion int) error { - tcpConn, connData, err := nc.networkConnect(protocolVersion) + connInterface, connData, err := nc.networkConnect(protocolVersion) if err != nil { return err } - nc.connData = connData - nc.tcpConn = tcpConn - - go nc.listen() + nc.connData = connData nc.drainCh = make(chan chan bool, 1) - go nc.loop(nc.bpCh) + switch c := connInterface.(type) { + case *net.TCPConn: + nc.tcpConn = c + go nc.listen() + go nc.loop() + case *tls.Conn: + nc.tlsConn = c + go nc.listenTLS() + go nc.loopTLS() + } + return nil } @@ -155,16 +188,11 @@ func (nc *nodeConn) reconnect(protocolVersion int) { if count > maxRetries { return } - tcpConn, connData, err := nc.networkConnect(protocolVersion) - if err != nil { + if err := nc.connect(protocolVersion); err != nil { log.Println(fmt.Printf("Failed to reconnect to server %s with %s, retrying ...%d\n", nc.connInfo, err, count)) count++ continue } - nc.tcpConn = tcpConn - nc.connData = connData - go nc.listen() - go nc.loop(nc.bpCh) nc.closed.Store(false) return } @@ -172,7 +200,7 @@ func (nc *nodeConn) reconnect(protocolVersion int) { } } -func (nc *nodeConn) networkConnect(protocolVersion int) (*net.TCPConn, *wire.ConnInfo, error) { +func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.ConnInfo, error) { u, err := parseURL(nc.connInfo) if err != nil { return nil, nil, err @@ -181,33 +209,59 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (*net.TCPConn, *wire.Con if err != nil { return nil, nil, fmt.Errorf("error resolving %v", nc.connInfo) } - tcpConn, err := net.DialTCP("tcp", nil, raddr) + if nc.pemBytes != nil { + roots := x509.NewCertPool() + ok := roots.AppendCertsFromPEM(nc.pemBytes) + if !ok { + log.Fatal("failed to parse root certificate") + } + config := &tls.Config{ + RootCAs: roots, + InsecureSkipVerify: nc.insecureSkipVerify, + } + conn, err := net.DialTCP("tcp", nil, raddr) + if err != nil { + return nil, nil, err + } + tlsConn := tls.Client(conn, config) + i, err := nc.setupConn(protocolVersion, u, tlsConn) + if err != nil { + tlsConn.Close() + return nil, nil, err + } + return tlsConn, i, nil + } + conn, err := net.DialTCP("tcp", nil, raddr) + i, err := nc.setupConn(protocolVersion, u, conn) if err != nil { - return nil, nil, fmt.Errorf("failed to connect to server %v", nc.connInfo) + conn.Close() + return nil, nil, err } + return conn, i, nil +} + +func (nc *nodeConn) setupConn(protocolVersion int, u *url.URL, tcpConn io.ReadWriter) (*wire.ConnInfo, error) { pass, _ := u.User.Password() encoder := wire.NewEncoder() login, err := encoder.Login(protocolVersion, u.User.Username(), pass) if err != nil { - tcpConn.Close() - return nil, nil, fmt.Errorf("failed to serialize login message %v", nc.connInfo) + return nil, fmt.Errorf("failed to serialize login message %v", nc.connInfo) } _, err = tcpConn.Write(login) if err != nil { - return nil, nil, err + return nil, err } decoder := wire.NewDecoder(tcpConn) i, err := decoder.Login() if err != nil { - tcpConn.Close() - return nil, nil, fmt.Errorf("failed to login to server %v", nc.connInfo) + return nil, fmt.Errorf("failed to login to server %v", nc.connInfo) } query := u.Query() retry := query.Get("retry") if retry != "" { r, err := strconv.ParseBool(retry) if err != nil { - return nil, nil, fmt.Errorf("voltdbclient: failed to parse retry value %v", err) + return nil, fmt.Errorf("voltdbclient: failed to parse retry value %v", err) } nc.retry = r @@ -215,7 +269,7 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (*net.TCPConn, *wire.Con if interval != "" { i, err := time.ParseDuration(interval) if err != nil { - return nil, nil, fmt.Errorf("voltdbclient: failed to parse retry_interval value %v", err) + return nil, fmt.Errorf("voltdbclient: failed to parse retry_interval value %v", err) } nc.retryInterval = i } @@ -223,12 +277,12 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (*net.TCPConn, *wire.Con if maxRetries != "" { max, err := strconv.Atoi(maxRetries) if err != nil { - return nil, nil, fmt.Errorf("voltdbclient: failed to parse max_retries value %v", err) + return nil, fmt.Errorf("voltdbclient: failed to parse max_retries value %v", err) } nc.maxRetries = max } } - return tcpConn, i, nil + return i, nil } func (nc *nodeConn) drain(respCh chan bool) { @@ -273,7 +327,7 @@ func (nc *nodeConn) listen() { } } -func (nc *nodeConn) loop(bpCh <-chan chan bool) { +func (nc *nodeConn) loop() { var draining bool var drainRespCh chan bool @@ -341,7 +395,127 @@ func (nc *nodeConn) loop(bpCh <-chan chan bool) { nc.handleAsyncResponse(handle, resp, req) } - case respBPCh := <-bpCh: + case respBPCh := <-nc.bpCh: + respBPCh <- nc.bp + case drainRespCh = <-nc.drainCh: + draining = true + // check for timed out procedure invocations + case <-tcc: + nc.requests.Range(func(_, v interface{}) bool { + req := v.(*networkRequest) + if time.Now().After(req.submitted.Add(req.timeout)) { + nc.queuedBytes -= req.numBytes + nc.handleTimeout(req) + nc.requests.Delete(req.handle) + } + return true + }) + tcc = time.NewTimer(time.Duration(tci) * time.Nanosecond).C + } + } +} + +// listenTLS listens for messages from the server and calls back a registered listener. +// listenTLS blocks on input from the server and should be run as a go routine. +func (nc *nodeConn) listenTLS() { + d := wire.NewDecoder(nc.tlsConn) + s := &wire.Decoder{} + for { + if nc.isClosed() { + return + } + b, err := d.Message() + if err != nil { + if nc.responseCh == nil { + // exiting + return + } + // TODO: put the error on the channel + // the owner needs to reconnect + return + } + buf := bytes.NewBuffer(b) + s.SetReader(buf) + _, err = s.Byte() + if err != nil { + if nc.responseCh == nil { + return + } + return + } + nc.responseCh <- buf + } +} + +func (nc *nodeConn) loopTLS() { + var draining bool + var drainRespCh chan bool + + var tci = int64(DefaultQueryTimeout / 10) // timeout check interval + tcc := time.NewTimer(time.Duration(tci) * time.Nanosecond).C // timeout check timer channel + + // for ping + var pingTimeout = 2 * time.Minute + pingSentTime := time.Now() + var pingOutstanding bool + for { + if nc.isClosed() { + return + } + // setup select cases + if draining { + if nc.queuedBytes <= 0 { + drainRespCh <- true + drainRespCh = nil + draining = false + } + } + + // ping + pingSinceSent := time.Now().Sub(pingSentTime) + if pingOutstanding { + if pingSinceSent > pingTimeout { + // TODO: should disconnect + } + } else if pingSinceSent > pingTimeout/3 { + nc.sendPing() + pingOutstanding = true + pingSentTime = time.Now() + } + + select { + case respCh := <-nc.closeCh: + nc.tlsConn.Close() + respCh <- true + return + case resp := <-nc.responseCh: + decoder := wire.NewDecoder(resp) + handle, err := decoder.Int64() + // can't do anything without a handle. If reading the handle fails, + // then log and drop the message. + if err != nil { + continue + } + if handle == PingHandle { + pingOutstanding = false + continue + } + r, ok := nc.requests.Load(handle) + if !ok || r == nil { + // there's a race here with timeout. A request can be timed out and + // then a response received. In this case drop the response. + continue + } + req := r.(*networkRequest) + nc.queuedBytes -= req.numBytes + nc.requests.Delete(handle) + if req.isSync() { + nc.handleSyncResponse(handle, resp, req) + } else { + nc.handleAsyncResponse(handle, resp, req) + } + + case respBPCh := <-nc.bpCh: respBPCh <- nc.bp case drainRespCh = <-nc.drainCh: draining = true @@ -372,7 +546,13 @@ func (nc *nodeConn) handleProcedureInvocation(pi *procedureInvocation) (int, err nc.queuedBytes += pi.slen encoder := wire.NewEncoder() EncodePI(encoder, pi) - n, err := nc.tcpConn.Write(encoder.Bytes()) + var n int + var err error + if nc.tlsConn == nil { + n, err = nc.tcpConn.Write(encoder.Bytes()) + } else { + n, err = nc.tlsConn.Write(encoder.Bytes()) + } if err != nil { if strings.Contains(err.Error(), "write: broken pipe") { return n, fmt.Errorf("node %s: is down", nc.connInfo) @@ -447,7 +627,12 @@ func (nc *nodeConn) sendPing() error { pi := newProcedureInvocationByHandle(PingHandle, true, "@Ping", []driver.Value{}) encoder := wire.NewEncoder() EncodePI(encoder, pi) - _, err := nc.tcpConn.Write(encoder.Bytes()) + var err error + if nc.tlsConn == nil { + _, err = nc.tcpConn.Write(encoder.Bytes()) + } else { + _, err = nc.tlsConn.Write(encoder.Bytes()) + } if err != nil { if strings.Contains(err.Error(), "write: broken pipe") { return fmt.Errorf("node %s: is down", nc.connInfo)