Skip to content

Commit ce78b51

Browse files
committed
Merge upstream changes to commit '0aa180bad'
2 parents 1a9cec5 + 0aa180b commit ce78b51

File tree

8 files changed

+304
-7
lines changed

8 files changed

+304
-7
lines changed

cassandra_test.go

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,92 @@ func TestCAS(t *testing.T) {
690690
}
691691
}
692692

693+
func TestConsistencySerial(t *testing.T) {
694+
session := createSession(t)
695+
defer session.Close()
696+
697+
type testStruct struct {
698+
name string
699+
id int
700+
consistency Consistency
701+
expectedPanicValue string
702+
}
703+
704+
testCases := []testStruct{
705+
{
706+
name: "Any",
707+
consistency: Any,
708+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got ANY",
709+
}, {
710+
name: "One",
711+
consistency: One,
712+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got ONE",
713+
}, {
714+
name: "Two",
715+
consistency: Two,
716+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got TWO",
717+
}, {
718+
name: "Three",
719+
consistency: Three,
720+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got THREE",
721+
}, {
722+
name: "Quorum",
723+
consistency: Quorum,
724+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got QUORUM",
725+
}, {
726+
name: "LocalQuorum",
727+
consistency: LocalQuorum,
728+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got LOCAL_QUORUM",
729+
}, {
730+
name: "EachQuorum",
731+
consistency: EachQuorum,
732+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got EACH_QUORUM",
733+
}, {
734+
name: "Serial",
735+
id: 8,
736+
consistency: Serial,
737+
expectedPanicValue: "",
738+
}, {
739+
name: "LocalSerial",
740+
id: 9,
741+
consistency: LocalSerial,
742+
expectedPanicValue: "",
743+
}, {
744+
name: "LocalOne",
745+
consistency: LocalOne,
746+
expectedPanicValue: "Serial consistency can only be SERIAL or LOCAL_SERIAL got LOCAL_ONE",
747+
},
748+
}
749+
750+
err := session.Query("CREATE TABLE IF NOT EXISTS gocql_test.consistency_serial (id int PRIMARY KEY)").Exec()
751+
if err != nil {
752+
t.Fatalf("can't create consistency_serial table:%v", err)
753+
}
754+
755+
for _, tc := range testCases {
756+
t.Run(tc.name, func(t *testing.T) {
757+
if tc.expectedPanicValue == "" {
758+
err = session.Query("INSERT INTO gocql_test.consistency_serial (id) VALUES (?)", tc.id).SerialConsistency(tc.consistency).Exec()
759+
if err != nil {
760+
t.Fatal(err)
761+
}
762+
763+
var receivedID int
764+
err = session.Query("SELECT * FROM gocql_test.consistency_serial WHERE id=?", tc.id).Scan(&receivedID)
765+
if err != nil {
766+
t.Fatal(err)
767+
}
768+
769+
require.Equal(t, tc.id, receivedID)
770+
} else {
771+
require.PanicsWithValue(t, tc.expectedPanicValue, func() {
772+
session.Query("INSERT INTO gocql_test.consistency_serial (id) VALUES (?)", tc.id).SerialConsistency(tc.consistency)
773+
})
774+
}
775+
})
776+
}
777+
}
778+
693779
func TestDurationType(t *testing.T) {
694780
session := createSession(t)
695781
defer session.Close()
@@ -2779,7 +2865,6 @@ func TestUnsetColBatch(t *testing.T) {
27792865
}
27802866
var id, mInt, count int
27812867
var mText string
2782-
27832868
if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
27842869
t.Fatalf("Failed to select with err: %v", err)
27852870
} else if count != 2 {
@@ -2814,3 +2899,52 @@ func TestQuery_NamedValues(t *testing.T) {
28142899
t.Fatal(err)
28152900
}
28162901
}
2902+
2903+
// This test ensures that queries are sent to the specified host only
2904+
func TestQuery_SetHostID(t *testing.T) {
2905+
session := createSession(t)
2906+
defer session.Close()
2907+
2908+
hosts := session.GetHosts()
2909+
2910+
const iterations = 5
2911+
for _, expectedHost := range hosts {
2912+
for i := 0; i < iterations; i++ {
2913+
var actualHostID string
2914+
err := session.Query("SELECT host_id FROM system.local").
2915+
SetHostID(expectedHost.HostID()).
2916+
Scan(&actualHostID)
2917+
if err != nil {
2918+
t.Fatal(err)
2919+
}
2920+
2921+
if expectedHost.HostID() != actualHostID {
2922+
t.Fatalf("Expected query to be executed on host %s, but it was executed on %s",
2923+
expectedHost.HostID(),
2924+
actualHostID,
2925+
)
2926+
}
2927+
}
2928+
}
2929+
2930+
// ensuring properly handled invalid host id
2931+
err := session.Query("SELECT host_id FROM system.local").
2932+
SetHostID("[invalid]").
2933+
Exec()
2934+
if !errors.Is(err, ErrNoPool) {
2935+
t.Fatalf("Expected error to be: %v, but got %v", ErrNoPool, err)
2936+
}
2937+
2938+
// ensuring that the driver properly handles the case
2939+
// when specified host for the query is down
2940+
host := hosts[0]
2941+
pool, _ := session.pool.getPoolByHostID(host.HostID())
2942+
// simulating specified host is down
2943+
pool.host.setState(NodeDown)
2944+
err = session.Query("SELECT host_id FROM system.local").
2945+
SetHostID(host.HostID()).
2946+
Exec()
2947+
if !errors.Is(err, ErrHostDown) {
2948+
t.Fatalf("Expected error to be: %v, but got %v", ErrHostDown, err)
2949+
}
2950+
}

cluster.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ type ClusterConfig struct {
173173

174174
// Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL.
175175
// Default: unset
176-
SerialConsistency SerialConsistency
176+
SerialConsistency Consistency
177177

178178
// SslOpts configures TLS use when HostDialer is not set.
179179
// SslOpts is ignored if HostDialer is set.
@@ -503,6 +503,10 @@ func (cfg *ClusterConfig) Validate() error {
503503
cfg.Logger.Println("warning: enabling skipping metadata can lead to unpredictible results when executing query and altering columns involved in the query.")
504504
}
505505

506+
if cfg.SerialConsistency > 0 && !cfg.SerialConsistency.IsSerial() {
507+
return fmt.Errorf("the default SerialConsistency level is not allowed to be anything else but SERIAL or LOCAL_SERIAL. Recived value: %v", cfg.SerialConsistency)
508+
}
509+
506510
return cfg.ValidateAndInitSSL()
507511
}
508512

connectionpool.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
196196
return
197197
}
198198

199+
func (p *policyConnPool) getPoolByHostID(hostID string) (pool *hostConnPool, ok bool) {
200+
p.mu.RLock()
201+
pool, ok = p.hostConnPools[hostID]
202+
p.mu.RUnlock()
203+
return
204+
}
205+
199206
func (p *policyConnPool) Close() {
200207
p.mu.Lock()
201208
defer p.mu.Unlock()

frame.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ func (c *Consistency) UnmarshalText(text []byte) error {
273273
default:
274274
return fmt.Errorf("invalid consistency %q", string(text))
275275
}
276-
277276
return nil
278277
}
279278

policies.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,34 @@ func (host selectedHost) Token() Token {
390390

391391
func (host selectedHost) Mark(err error) {}
392392

393+
func newSingleHost(info *HostInfo, maxRetries byte, retryDelay time.Duration) *singleHost {
394+
return &singleHost{info: info, maxRetries: maxRetries, delay: retryDelay}
395+
}
396+
397+
type singleHost struct {
398+
retry byte
399+
maxRetries byte
400+
delay time.Duration
401+
info *HostInfo
402+
}
403+
404+
func (s *singleHost) selectHost() SelectedHost {
405+
if s.retry >= s.maxRetries {
406+
return nil
407+
}
408+
if s.retry > 0 && s.delay > 0 {
409+
time.Sleep(s.delay)
410+
}
411+
s.retry++
412+
return s
413+
}
414+
415+
func (s singleHost) Info() *HostInfo { return s.info }
416+
417+
func (s singleHost) Token() Token { return nil }
418+
419+
func (s singleHost) Mark(error) {}
420+
393421
// NextHost is an iteration function over picked hosts
394422
type NextHost func() SelectedHost
395423

policies_integration_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
package gocql
55

66
import (
7+
"context"
78
"testing"
9+
"time"
810
)
911

1012
// Check if session fail to start if DC name provided in the policy is wrong
@@ -39,3 +41,64 @@ func TestDCValidationRackAware(t *testing.T) {
3941
t.Fatal("createSession was expected to fail with wrong DC name provided.")
4042
}
4143
}
44+
45+
// This test ensures that when all hosts are down, the query execution does not hang.
46+
func TestNoHangAllHostsDown(t *testing.T) {
47+
cluster := createCluster()
48+
session := createSessionFromCluster(cluster, t)
49+
50+
hosts := session.GetHosts()
51+
dc := hosts[0].DataCenter()
52+
rack := hosts[0].Rack()
53+
session.Close()
54+
55+
policies := []HostSelectionPolicy{
56+
DCAwareRoundRobinPolicy(dc),
57+
DCAwareRoundRobinPolicy(dc, HostPolicyOptionDisableDCFailover),
58+
TokenAwareHostPolicy(DCAwareRoundRobinPolicy(dc)),
59+
TokenAwareHostPolicy(DCAwareRoundRobinPolicy(dc, HostPolicyOptionDisableDCFailover)),
60+
RackAwareRoundRobinPolicy(dc, rack),
61+
RackAwareRoundRobinPolicy(dc, rack, HostPolicyOptionDisableDCFailover),
62+
TokenAwareHostPolicy(RackAwareRoundRobinPolicy(dc, rack)),
63+
TokenAwareHostPolicy(RackAwareRoundRobinPolicy(dc, rack, HostPolicyOptionDisableDCFailover)),
64+
nil,
65+
}
66+
67+
for _, policy := range policies {
68+
cluster = createCluster()
69+
cluster.PoolConfig.HostSelectionPolicy = policy
70+
session = createSessionFromCluster(cluster, t)
71+
hosts = session.GetHosts()
72+
73+
// simulating hosts are down
74+
for _, host := range hosts {
75+
pool, _ := session.pool.getPoolByHostID(host.HostID())
76+
pool.host.setState(NodeDown)
77+
if policy != nil {
78+
policy.AddHost(host)
79+
}
80+
}
81+
82+
ctx, _ := context.WithTimeout(context.Background(), 12*time.Second)
83+
_ = session.Query("SELECT host_id FROM system.local").WithContext(ctx).Exec()
84+
if ctx.Err() != nil {
85+
t.Errorf("policy %T should be no hangups when all hosts are down", policy)
86+
}
87+
88+
// remove all host except one
89+
if policy != nil {
90+
for i, host := range hosts {
91+
if i != 0 {
92+
policy.RemoveHost(host)
93+
}
94+
}
95+
}
96+
97+
ctx, _ = context.WithTimeout(context.Background(), 12*time.Second)
98+
_ = session.Query("SELECT host_id FROM system.local").WithContext(ctx).Exec()
99+
if ctx.Err() != nil {
100+
t.Errorf("policy %T should be no hangups when all hosts are down", policy)
101+
}
102+
session.Close()
103+
}
104+
}

query_executor.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ package gocql
2727
import (
2828
"context"
2929
"errors"
30+
"fmt"
3031
"sync"
3132
"time"
3233
)
@@ -44,6 +45,7 @@ type ExecutableQuery interface {
4445
IsIdempotent() bool
4546
IsLWT() bool
4647
GetCustomPartitioner() Partitioner
48+
GetHostID() string
4749

4850
withContext(context.Context) ExecutableQuery
4951

@@ -88,12 +90,29 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S
8890
}
8991

9092
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
91-
hostIter := q.policy.Pick(qry)
93+
var hostIter NextHost
94+
95+
// check if the hostID is specified for the query,
96+
// if true - the query execute at the specified host.
97+
// if false - the query execute at the host picked by HostSelectionPolicy
98+
if hostID := qry.GetHostID(); hostID != "" {
99+
pool, ok := q.pool.getPoolByHostID(hostID)
100+
if !ok {
101+
// if the specified host ID have no connection pool we return error
102+
return nil, fmt.Errorf("query is targeting unknown host id %s: %w", hostID, ErrNoPool)
103+
} else if pool.Size() == 0 {
104+
// if the pool have no connection we return error
105+
return nil, fmt.Errorf("query is targeting host id %s that driver is not connected to: %w", hostID, ErrNoConnectionsInPool)
106+
}
107+
hostIter = newSingleHost(pool.host, 5, 200*time.Millisecond).selectHost
108+
} else {
109+
hostIter = q.policy.Pick(qry)
110+
}
92111

93112
// check if the query is not marked as idempotent, if
94113
// it is, we force the policy to NonSpeculative
95114
sp := qry.speculativeExecutionPolicy()
96-
if !qry.IsIdempotent() || sp.Attempts() == 0 {
115+
if qry.GetHostID() != "" || !qry.IsIdempotent() || sp.Attempts() == 0 {
97116
return q.do(qry.Context(), qry, hostIter), nil
98117
}
99118

0 commit comments

Comments
 (0)