diff --git a/voltdbclient/distributor.go b/voltdbclient/distributor.go index f0298f2..04b9369 100644 --- a/voltdbclient/distributor.go +++ b/voltdbclient/distributor.go @@ -32,7 +32,8 @@ import ( const ( // DefaultQueryTimeout time out for queries. - DefaultQueryTimeout time.Duration = 2 * time.Minute + DefaultQueryTimeout time.Duration = 2 * time.Minute + DefaultConnectionTimeout time.Duration = 1 * time.Minute ) var handle int64 @@ -80,14 +81,14 @@ func newTLSConn(cis []string, clientConfig ClientConfig) (*Conn, error) { } c.open.Store(true) - if err := c.start(cis, clientConfig.InsecureSkipVerify); err != nil { + if err := c.startWithTimeout(cis, clientConfig.InsecureSkipVerify, clientConfig.ConnectTimeout); err != nil { return nil, err } return c, nil } -func newConn(cis []string) (*Conn, error) { +func newConn(cis []string, duration time.Duration) (*Conn, error) { var c = &Conn{ closeCh: make(chan chan bool), rl: newTxnLimiter(), @@ -97,7 +98,7 @@ func newConn(cis []string) (*Conn, error) { } c.open.Store(true) - if err := c.start(cis, false); err != nil { + if err := c.startWithTimeout(cis, false, duration); err != nil { return nil, err } @@ -139,13 +140,17 @@ func newConn(cis []string) (*Conn, error) { // This has no effect when retry is false. // // retry_interval is the duration of time to wait until the next retry. -func OpenConn(ci string) (*Conn, error) { +func OpenConnWithTimeout(ci string, duration time.Duration) (*Conn, error) { ci = strings.TrimSpace(ci) if ci == "" { return nil, ErrMissingServerArgument } cis := strings.Split(ci, ",") - return newConn(cis) + return newConn(cis, duration) +} + +func OpenConn(ci string) (*Conn, error) { + return OpenConnWithTimeout(ci, DefaultConnectionTimeout) } // OpenTLSConn uses TLS for network connections @@ -162,6 +167,7 @@ type ClientConfig struct { PEMPath string TLSConfig *tls.Config InsecureSkipVerify bool + ConnectTimeout time.Duration } // OpenConnWithLatencyTarget returns a new connection to the VoltDB server. @@ -173,7 +179,7 @@ func OpenConnWithLatencyTarget(ci string, latencyTarget int32) (*Conn, error) { return nil, ErrMissingServerArgument } cis := strings.Split(ci, ",") - c, err := newConn(cis) + c, err := newConn(cis, DefaultConnectionTimeout) if err != nil { return nil, err } @@ -191,7 +197,7 @@ func OpenConnWithMaxOutstandingTxns(ci string, maxOutTxns int) (*Conn, error) { return nil, ErrMissingServerArgument } cis := strings.Split(ci, ",") - c, err := newConn(cis) + c, err := newConn(cis, DefaultConnectionTimeout) if err != nil { return nil, err } @@ -199,7 +205,7 @@ func OpenConnWithMaxOutstandingTxns(ci string, maxOutTxns int) (*Conn, error) { return c, nil } -func (c *Conn) start(cis []string, insecureSkipVerify bool) error { +func (c *Conn) startWithTimeout(cis []string, insecureSkipVerify bool, duration time.Duration) error { var ( err error disconnected []*nodeConn @@ -214,12 +220,12 @@ func (c *Conn) start(cis []string, insecureSkipVerify bool) error { if err != nil { return err } - nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, PEMBytes) + nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, PEMBytes, duration) } else { - nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, nil) + nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, nil, duration) } } else { - nc = newNodeConn(ci) + nc = newNodeConnWithTimeout(ci, duration) } if err = nc.connect(ProtocolVersion); err != nil { disconnected = append(disconnected, nc) diff --git a/voltdbclient/driver.go b/voltdbclient/driver.go index 3479f2a..0d3793d 100644 --- a/voltdbclient/driver.go +++ b/voltdbclient/driver.go @@ -20,6 +20,7 @@ package voltdbclient import ( "database/sql" "database/sql/driver" + "time" ) // VoltDriver implements A database/sql/driver for VoltDB. This driver is @@ -33,7 +34,12 @@ func NewVoltDriver() *VoltDriver { // Open a connection to the VoltDB server. func (vd *VoltDriver) Open(hostAndPort string) (driver.Conn, error) { - return OpenConn(hostAndPort) + return vd.OpenWithConnectTimeout(hostAndPort, DefaultConnectionTimeout) +} + +// Open a connection to the VoltDB server. +func (vd *VoltDriver) OpenWithConnectTimeout(hostAndPort string, duration time.Duration) (driver.Conn, error) { + return OpenConnWithTimeout(hostAndPort, duration) } func init() { diff --git a/voltdbclient/node_conn.go b/voltdbclient/node_conn.go index ea17001..59a8e58 100644 --- a/voltdbclient/node_conn.go +++ b/voltdbclient/node_conn.go @@ -78,9 +78,10 @@ type nodeConn struct { // giving up. maxRetries int tlsConfig *tls.Config + connectTimeout time.Duration } -func newNodeConn(ci string) *nodeConn { +func newNodeConnWithTimeout(ci string, duration time.Duration) *nodeConn { u, _ := parseURL(ci) return &nodeConn{ connInfo: ci, @@ -90,10 +91,15 @@ func newNodeConn(ci string) *nodeConn { drainCh: make(chan chan bool), responseCh: make(chan *bytes.Buffer, maxResponseBuffer), requests: &sync.Map{}, + connectTimeout: duration, } } -func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, pemBytes []byte) *nodeConn { +func newNodeConn(ci string) *nodeConn { + return newNodeConnWithTimeout(ci, DefaultConnectionTimeout) +} + +func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, pemBytes []byte, duration time.Duration) *nodeConn { u, _ := parseURL(ci) return &nodeConn{ pemBytes: pemBytes, @@ -106,6 +112,7 @@ func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, p drainCh: make(chan chan bool), responseCh: make(chan *bytes.Buffer, maxResponseBuffer), requests: &sync.Map{}, + connectTimeout: duration, } } @@ -213,6 +220,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn if err != nil { return nil, nil, err } + to := nc.connectTimeout + if to <= 0 { + to = DefaultConnectionTimeout + } raddr, err := net.ResolveTCPAddr("tcp", u.Host) if err != nil { return nil, nil, fmt.Errorf("error resolving %v", nc.Host) @@ -230,7 +241,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn InsecureSkipVerify: nc.insecureSkipVerify, } } - conn, err := net.DialTCP("tcp", nil, raddr) + dialer := net.Dialer{ + Timeout: to, + } + conn, err := dialer.Dial("tcp", raddr.String()) if err != nil { return nil, nil, err } @@ -242,7 +256,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn } return tlsConn, i, nil } - conn, err := net.DialTCP("tcp", nil, raddr) + dialer := net.Dialer{ + Timeout: nc.connectTimeout, + } + conn, err := dialer.Dial("tcp", raddr.String()) if err != nil { return nil, nil, err }