Skip to content

Commit

Permalink
Add connectTimeout (#80)
Browse files Browse the repository at this point in the history
* Add connectTimeout

* Cleanup.

* Cleanup2
  • Loading branch information
akhanzode authored Nov 30, 2021
1 parent b98afaf commit e5351c9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
30 changes: 18 additions & 12 deletions voltdbclient/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import (

const (
// DefaultQueryTimeout time out for queries.
DefaultQueryTimeout time.Duration = 2 * time.Minute
DefaultQueryTimeout time.Duration = 2 * time.Minute
DefaultConnectionTimeout time.Duration = 1 * time.Minute
)

var handle int64
Expand Down Expand Up @@ -80,14 +81,14 @@ func newTLSConn(cis []string, clientConfig ClientConfig) (*Conn, error) {
}
c.open.Store(true)

if err := c.start(cis, clientConfig.InsecureSkipVerify); err != nil {
if err := c.startWithTimeout(cis, clientConfig.InsecureSkipVerify, clientConfig.ConnectTimeout); err != nil {
return nil, err
}

return c, nil
}

func newConn(cis []string) (*Conn, error) {
func newConn(cis []string, duration time.Duration) (*Conn, error) {
var c = &Conn{
closeCh: make(chan chan bool),
rl: newTxnLimiter(),
Expand All @@ -97,7 +98,7 @@ func newConn(cis []string) (*Conn, error) {
}
c.open.Store(true)

if err := c.start(cis, false); err != nil {
if err := c.startWithTimeout(cis, false, duration); err != nil {
return nil, err
}

Expand Down Expand Up @@ -139,13 +140,17 @@ func newConn(cis []string) (*Conn, error) {
// This has no effect when retry is false.
//
// retry_interval is the duration of time to wait until the next retry.
func OpenConn(ci string) (*Conn, error) {
func OpenConnWithTimeout(ci string, duration time.Duration) (*Conn, error) {
ci = strings.TrimSpace(ci)
if ci == "" {
return nil, ErrMissingServerArgument
}
cis := strings.Split(ci, ",")
return newConn(cis)
return newConn(cis, duration)
}

func OpenConn(ci string) (*Conn, error) {
return OpenConnWithTimeout(ci, DefaultConnectionTimeout)
}

// OpenTLSConn uses TLS for network connections
Expand All @@ -162,6 +167,7 @@ type ClientConfig struct {
PEMPath string
TLSConfig *tls.Config
InsecureSkipVerify bool
ConnectTimeout time.Duration
}

// OpenConnWithLatencyTarget returns a new connection to the VoltDB server.
Expand All @@ -173,7 +179,7 @@ func OpenConnWithLatencyTarget(ci string, latencyTarget int32) (*Conn, error) {
return nil, ErrMissingServerArgument
}
cis := strings.Split(ci, ",")
c, err := newConn(cis)
c, err := newConn(cis, DefaultConnectionTimeout)
if err != nil {
return nil, err
}
Expand All @@ -191,15 +197,15 @@ func OpenConnWithMaxOutstandingTxns(ci string, maxOutTxns int) (*Conn, error) {
return nil, ErrMissingServerArgument
}
cis := strings.Split(ci, ",")
c, err := newConn(cis)
c, err := newConn(cis, DefaultConnectionTimeout)
if err != nil {
return nil, err
}
c.rl = newTxnLimiterWithMaxOutTxns(maxOutTxns)
return c, nil
}

func (c *Conn) start(cis []string, insecureSkipVerify bool) error {
func (c *Conn) startWithTimeout(cis []string, insecureSkipVerify bool, duration time.Duration) error {
var (
err error
disconnected []*nodeConn
Expand All @@ -214,12 +220,12 @@ func (c *Conn) start(cis []string, insecureSkipVerify bool) error {
if err != nil {
return err
}
nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, PEMBytes)
nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, PEMBytes, duration)
} else {
nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, nil)
nc = newNodeTLSConn(ci, insecureSkipVerify, c.tlsConfig, nil, duration)
}
} else {
nc = newNodeConn(ci)
nc = newNodeConnWithTimeout(ci, duration)
}
if err = nc.connect(ProtocolVersion); err != nil {
disconnected = append(disconnected, nc)
Expand Down
8 changes: 7 additions & 1 deletion voltdbclient/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package voltdbclient
import (
"database/sql"
"database/sql/driver"
"time"
)

// VoltDriver implements A database/sql/driver for VoltDB. This driver is
Expand All @@ -33,7 +34,12 @@ func NewVoltDriver() *VoltDriver {

// Open a connection to the VoltDB server.
func (vd *VoltDriver) Open(hostAndPort string) (driver.Conn, error) {
return OpenConn(hostAndPort)
return vd.OpenWithConnectTimeout(hostAndPort, DefaultConnectionTimeout)
}

// Open a connection to the VoltDB server.
func (vd *VoltDriver) OpenWithConnectTimeout(hostAndPort string, duration time.Duration) (driver.Conn, error) {
return OpenConnWithTimeout(hostAndPort, duration)
}

func init() {
Expand Down
25 changes: 21 additions & 4 deletions voltdbclient/node_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ type nodeConn struct {
// giving up.
maxRetries int
tlsConfig *tls.Config
connectTimeout time.Duration
}

func newNodeConn(ci string) *nodeConn {
func newNodeConnWithTimeout(ci string, duration time.Duration) *nodeConn {
u, _ := parseURL(ci)
return &nodeConn{
connInfo: ci,
Expand All @@ -90,10 +91,15 @@ func newNodeConn(ci string) *nodeConn {
drainCh: make(chan chan bool),
responseCh: make(chan *bytes.Buffer, maxResponseBuffer),
requests: &sync.Map{},
connectTimeout: duration,
}
}

func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, pemBytes []byte) *nodeConn {
func newNodeConn(ci string) *nodeConn {
return newNodeConnWithTimeout(ci, DefaultConnectionTimeout)
}

func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, pemBytes []byte, duration time.Duration) *nodeConn {
u, _ := parseURL(ci)
return &nodeConn{
pemBytes: pemBytes,
Expand All @@ -106,6 +112,7 @@ func newNodeTLSConn(ci string, insecureSkipVerify bool, tlsConfig *tls.Config, p
drainCh: make(chan chan bool),
responseCh: make(chan *bytes.Buffer, maxResponseBuffer),
requests: &sync.Map{},
connectTimeout: duration,
}
}

Expand Down Expand Up @@ -213,6 +220,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn
if err != nil {
return nil, nil, err
}
to := nc.connectTimeout
if to <= 0 {
to = DefaultConnectionTimeout
}
raddr, err := net.ResolveTCPAddr("tcp", u.Host)
if err != nil {
return nil, nil, fmt.Errorf("error resolving %v", nc.Host)
Expand All @@ -230,7 +241,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn
InsecureSkipVerify: nc.insecureSkipVerify,
}
}
conn, err := net.DialTCP("tcp", nil, raddr)
dialer := net.Dialer{
Timeout: to,
}
conn, err := dialer.Dial("tcp", raddr.String())
if err != nil {
return nil, nil, err
}
Expand All @@ -242,7 +256,10 @@ func (nc *nodeConn) networkConnect(protocolVersion int) (interface{}, *wire.Conn
}
return tlsConn, i, nil
}
conn, err := net.DialTCP("tcp", nil, raddr)
dialer := net.Dialer{
Timeout: nc.connectTimeout,
}
conn, err := dialer.Dial("tcp", raddr.String())
if err != nil {
return nil, nil, err
}
Expand Down

0 comments on commit e5351c9

Please sign in to comment.