Skip to content

Commit 40ba76e

Browse files
committed
Apply address translator right after hosts resolved
In order for DNS feature to work properly with AddressTranslator we need to address being translated right after DNS is resolved.
1 parent e56c191 commit 40ba76e

File tree

6 files changed

+15
-10
lines changed

6 files changed

+15
-10
lines changed

conn_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ func (srv *TestServer) session() (*Session, error) {
12571257
}
12581258

12591259
func (srv *TestServer) host() *HostInfo {
1260-
hosts, err := hostInfo(nil, srv.Address, 9042)
1260+
hosts, err := hostInfo(nil, nil, srv.Address, 9042)
12611261
if err != nil {
12621262
srv.t.Fatal(err)
12631263
}

control.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func (c *controlConn) heartBeat() {
141141
}
142142
}
143143

144-
func hostInfo(resolver DNSResolver, addr string, defaultPort int) ([]*HostInfo, error) {
144+
func hostInfo(resolver DNSResolver, translateAddressPort func(addr net.IP, port int) (net.IP, int), addr string, defaultPort int) ([]*HostInfo, error) {
145145
var port int
146146
host, portStr, err := net.SplitHostPort(addr)
147147
if err != nil {
@@ -174,7 +174,12 @@ func hostInfo(resolver DNSResolver, addr string, defaultPort int) ([]*HostInfo,
174174

175175
for _, ip := range ips {
176176
if validIpAddr(ip) {
177-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
177+
hh := &HostInfo{hostname: host, connectAddress: ip, port: port}
178+
hh.untranslatedConnectAddress = ip
179+
if translateAddressPort != nil {
180+
hh.connectAddress, hh.port = translateAddressPort(ip, port)
181+
}
182+
hosts = append(hosts, hh)
178183
}
179184
}
180185

@@ -424,7 +429,7 @@ func (c *controlConn) attemptReconnect() error {
424429
c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
425430
// Fallback to initial contact points, as it may be the case that all known initialHosts
426431
// changed their IPs while keeping the same hostname(s).
427-
initialHosts, resolvErr := addrsToHosts(c.session.cfg.DNSResolver, c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
432+
initialHosts, resolvErr := addrsToHosts(c.session.cfg.DNSResolver, c.session.cfg.translateAddressPort, c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
428433
if resolvErr != nil {
429434
return fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
430435
}

control_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestHostInfo_Lookup(t *testing.T) {
4444
}
4545

4646
for i, test := range tests {
47-
hosts, err := hostInfo(resolver, test.addr, 1)
47+
hosts, err := hostInfo(resolver, nil, test.addr, 1)
4848
if err != nil {
4949
t.Errorf("%d: %v", i, err)
5050
continue

exec.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor,
6969
}
7070

7171
var hosts []*HostInfo
72-
if hosts, err = addrsToHosts(c.DNSResolver, c.Hosts, c.Port, c.Logger); err != nil {
72+
if hosts, err = addrsToHosts(c.DNSResolver, c.translateAddressPort, c.Hosts, c.Port, c.Logger); err != nil {
7373
err = fmt.Errorf("addrs to hosts: %w", err)
7474
return
7575
}

filters.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func DataCentreHostFilter(dataCenter string) HostFilter {
7070
// WhiteListHostFilter filters incoming hosts by checking that their address is
7171
// in the initial hosts whitelist.
7272
func WhiteListHostFilter(hosts ...string) HostFilter {
73-
hostInfos, err := addrsToHosts(defaultDnsResolver, hosts, 9042, nopLogger{})
73+
hostInfos, err := addrsToHosts(defaultDnsResolver, nil, hosts, 9042, nopLogger{})
7474
if err != nil {
7575
// dont want to panic here, but rather not break the API
7676
panic(fmt.Errorf("unable to lookup host info from address: %v", err))

session.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ var queryPool = &sync.Pool{
116116
},
117117
}
118118

119-
func addrsToHosts(resolver DNSResolver, addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) {
119+
func addrsToHosts(resolver DNSResolver, translateAddressPort func(addr net.IP, port int) (net.IP, int), addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) {
120120
var hosts []*HostInfo
121121
for _, hostaddr := range addrs {
122-
resolvedHosts, err := hostInfo(resolver, hostaddr, defaultPort)
122+
resolvedHosts, err := hostInfo(resolver, translateAddressPort, hostaddr, defaultPort)
123123
if err != nil {
124124
// Try other hosts if unable to resolve DNS name
125125
if _, ok := err.(*net.DNSError); ok {
@@ -259,7 +259,7 @@ func (s *Session) init() error {
259259
return nil
260260
}
261261

262-
hosts, err := addrsToHosts(s.cfg.DNSResolver, s.cfg.Hosts, s.cfg.Port, s.logger)
262+
hosts, err := addrsToHosts(s.cfg.DNSResolver, s.cfg.translateAddressPort, s.cfg.Hosts, s.cfg.Port, s.logger)
263263
if err != nil {
264264
return err
265265
}

0 commit comments

Comments
 (0)