Skip to content

Commit 4810ae9

Browse files
authored
make sql.HashOf() collation aware (#3027)
1 parent 9d8e43b commit 4810ae9

File tree

16 files changed

+228
-147
lines changed

16 files changed

+228
-147
lines changed

enginetest/queries/script_queries.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8732,6 +8732,22 @@ where
87328732
},
87338733
},
87348734
},
8735+
{
8736+
Name: "subquery with case insensitive collation",
8737+
Dialect: "mysql",
8738+
SetUpScript: []string{
8739+
"create table tbl (t text) collate=utf8mb4_0900_ai_ci;",
8740+
"insert into tbl values ('abcdef');",
8741+
},
8742+
Assertions: []ScriptTestAssertion{
8743+
{
8744+
Query: "select 'AbCdEf' in (select t from tbl);",
8745+
Expected: []sql.Row{
8746+
{true},
8747+
},
8748+
},
8749+
},
8750+
},
87358751
}
87368752

87378753
var SpatialScriptTests = []ScriptTest{

memory/table_data.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package memory
1616

1717
import (
18-
"context"
1918
"fmt"
2019
"sort"
2120
"strconv"
@@ -25,6 +24,7 @@ import (
2524

2625
"github.com/dolthub/go-mysql-server/sql"
2726
"github.com/dolthub/go-mysql-server/sql/expression"
27+
"github.com/dolthub/go-mysql-server/sql/hash"
2828
"github.com/dolthub/go-mysql-server/sql/transform"
2929
"github.com/dolthub/go-mysql-server/sql/types"
3030
)
@@ -275,7 +275,7 @@ func (td *TableData) numRows(ctx *sql.Context) (uint64, error) {
275275
}
276276

277277
// throws an error if any two or more rows share the same |cols| values.
278-
func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string, idxName string) error {
278+
func (td *TableData) errIfDuplicateEntryExist(ctx *sql.Context, cols []string, idxName string) error {
279279
columnMapping, err := td.columnIndexes(cols)
280280

281281
// We currently skip validating duplicates on unique virtual columns.
@@ -297,7 +297,7 @@ func (td *TableData) errIfDuplicateEntryExist(ctx context.Context, cols []string
297297
if hasNulls(idxPrefixKey) {
298298
continue
299299
}
300-
h, err := sql.HashOf(ctx, idxPrefixKey)
300+
h, err := hash.HashOf(ctx, td.schema.Schema, idxPrefixKey)
301301
if err != nil {
302302
return err
303303
}

sql/cache.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,49 +15,12 @@
1515
package sql
1616

1717
import (
18-
"context"
1918
"fmt"
2019
"runtime"
21-
"sync"
22-
23-
"github.com/cespare/xxhash/v2"
2420

2521
lru "github.com/hashicorp/golang-lru"
2622
)
2723

28-
// HashOf returns a hash of the given value to be used as key in a cache.
29-
func HashOf(ctx context.Context, v Row) (uint64, error) {
30-
hash := digestPool.Get().(*xxhash.Digest)
31-
hash.Reset()
32-
defer digestPool.Put(hash)
33-
for i, x := range v {
34-
if i > 0 {
35-
// separate each value in the row with a nil byte
36-
if _, err := hash.Write([]byte{0}); err != nil {
37-
return 0, err
38-
}
39-
}
40-
x, err := UnwrapAny(ctx, x)
41-
if err != nil {
42-
return 0, err
43-
}
44-
// TODO: probably much faster to do this with a type switch
45-
// TODO: we don't have the type info necessary to appropriately encode the value of a string with a non-standard
46-
// collation, which means that two strings that differ only in their collations will hash to the same value.
47-
// See rowexec/grouping_key()
48-
if _, err := fmt.Fprintf(hash, "%v,", x); err != nil {
49-
return 0, err
50-
}
51-
}
52-
return hash.Sum64(), nil
53-
}
54-
55-
var digestPool = sync.Pool{
56-
New: func() any {
57-
return xxhash.New()
58-
},
59-
}
60-
6124
// ErrKeyNotFound is returned when the key could not be found in the cache.
6225
var ErrKeyNotFound = fmt.Errorf("memory: key not found in cache")
6326

sql/cache_test.go

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package sql
1616

1717
import (
18-
"context"
1918
"errors"
2019
"testing"
2120

@@ -178,35 +177,3 @@ func TestRowsCache(t *testing.T) {
178177
require.True(freed)
179178
})
180179
}
181-
182-
func BenchmarkHashOf(b *testing.B) {
183-
ctx := context.Background()
184-
row := NewRow(1, "1")
185-
b.ResetTimer()
186-
for i := 0; i < b.N; i++ {
187-
sum, err := HashOf(ctx, row)
188-
if err != nil {
189-
b.Fatal(err)
190-
}
191-
if sum != 11268758894040352165 {
192-
b.Fatalf("got %v", sum)
193-
}
194-
}
195-
}
196-
197-
func BenchmarkParallelHashOf(b *testing.B) {
198-
ctx := context.Background()
199-
row := NewRow(1, "1")
200-
b.ResetTimer()
201-
b.RunParallel(func(pb *testing.PB) {
202-
for pb.Next() {
203-
sum, err := HashOf(ctx, row)
204-
if err != nil {
205-
b.Fatal(err)
206-
}
207-
if sum != 11268758894040352165 {
208-
b.Fatalf("got %v", sum)
209-
}
210-
}
211-
})
212-
}

sql/hash/hash.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package hash
16+
17+
import (
18+
"fmt"
19+
"sync"
20+
21+
"github.com/cespare/xxhash/v2"
22+
23+
"github.com/dolthub/go-mysql-server/sql"
24+
"github.com/dolthub/go-mysql-server/sql/types"
25+
)
26+
27+
var digestPool = sync.Pool{
28+
New: func() any {
29+
return xxhash.New()
30+
},
31+
}
32+
33+
// HashOf returns a hash of the given value to be used as key in a cache.
34+
func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) {
35+
hash := digestPool.Get().(*xxhash.Digest)
36+
hash.Reset()
37+
defer digestPool.Put(hash)
38+
for i, v := range row {
39+
if i > 0 {
40+
// separate each value in the row with a nil byte
41+
if _, err := hash.Write([]byte{0}); err != nil {
42+
return 0, err
43+
}
44+
}
45+
46+
v, err := sql.UnwrapAny(ctx, v)
47+
if err != nil {
48+
return 0, fmt.Errorf("error unwrapping value: %w", err)
49+
}
50+
51+
// TODO: we may not always have the type information available, so we check schema length.
52+
// Then, defer to original behavior
53+
if i >= len(sch) || v == nil {
54+
_, err := fmt.Fprintf(hash, "%v", v)
55+
if err != nil {
56+
return 0, err
57+
}
58+
continue
59+
}
60+
61+
switch typ := sch[i].Type.(type) {
62+
case types.ExtendedType:
63+
// TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL,
64+
// so we're using the old (probably incorrect) behavior for now
65+
_, err = fmt.Fprintf(hash, "%v", v)
66+
if err != nil {
67+
return 0, err
68+
}
69+
case types.StringType:
70+
var strVal string
71+
strVal, err = types.ConvertToString(ctx, v, typ, nil)
72+
if err != nil {
73+
return 0, err
74+
}
75+
err = typ.Collation().WriteWeightString(hash, strVal)
76+
if err != nil {
77+
return 0, err
78+
}
79+
default:
80+
// TODO: probably much faster to do this with a type switch
81+
_, err = fmt.Fprintf(hash, "%v", v)
82+
if err != nil {
83+
return 0, err
84+
}
85+
}
86+
}
87+
return hash.Sum64(), nil
88+
}

sql/hash/hash_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package hash
16+
17+
import (
18+
"testing"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
)
22+
23+
func BenchmarkHashOf(b *testing.B) {
24+
ctx := sql.NewEmptyContext()
25+
row := sql.NewRow(1, "1")
26+
b.ResetTimer()
27+
for i := 0; i < b.N; i++ {
28+
sum, err := HashOf(ctx, nil, row)
29+
if err != nil {
30+
b.Fatal(err)
31+
}
32+
if sum != 11268758894040352165 {
33+
b.Fatalf("got %v", sum)
34+
}
35+
}
36+
}
37+
38+
func BenchmarkParallelHashOf(b *testing.B) {
39+
ctx := sql.NewEmptyContext()
40+
row := sql.NewRow(1, "1")
41+
b.ResetTimer()
42+
b.RunParallel(func(pb *testing.PB) {
43+
for pb.Next() {
44+
sum, err := HashOf(ctx, nil, row)
45+
if err != nil {
46+
b.Fatal(err)
47+
}
48+
if sum != 11268758894040352165 {
49+
b.Fatalf("got %v", sum)
50+
}
51+
}
52+
})
53+
}

sql/iters/rel_iters.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/dolthub/go-mysql-server/sql"
2626
"github.com/dolthub/go-mysql-server/sql/expression"
27+
"github.com/dolthub/go-mysql-server/sql/hash"
2728
"github.com/dolthub/go-mysql-server/sql/types"
2829
)
2930

@@ -571,7 +572,7 @@ func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) {
571572
return nil, err
572573
}
573574

574-
hash, err := sql.HashOf(ctx, row)
575+
hash, err := hash.HashOf(ctx, nil, row)
575576
if err != nil {
576577
return nil, err
577578
}
@@ -643,22 +644,21 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) {
643644
ii.cache = make(map[uint64]int)
644645
for {
645646
res, err := ii.RIter.Next(ctx)
646-
if err != nil && err != io.EOF {
647+
if err != nil {
648+
if err == io.EOF {
649+
break
650+
}
647651
return nil, err
648652
}
649653

650-
hash, herr := sql.HashOf(ctx, res)
654+
hash, herr := hash.HashOf(ctx, nil, res)
651655
if herr != nil {
652656
return nil, herr
653657
}
654658
if _, ok := ii.cache[hash]; !ok {
655659
ii.cache[hash] = 0
656660
}
657661
ii.cache[hash]++
658-
659-
if err == io.EOF {
660-
break
661-
}
662662
}
663663
ii.cached = true
664664
}
@@ -669,7 +669,7 @@ func (ii *IntersectIter) Next(ctx *sql.Context) (sql.Row, error) {
669669
return nil, err
670670
}
671671

672-
hash, herr := sql.HashOf(ctx, res)
672+
hash, herr := hash.HashOf(ctx, nil, res)
673673
if herr != nil {
674674
return nil, herr
675675
}
@@ -714,7 +714,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) {
714714
return nil, err
715715
}
716716

717-
hash, herr := sql.HashOf(ctx, res)
717+
hash, herr := hash.HashOf(ctx, nil, res)
718718
if herr != nil {
719719
return nil, herr
720720
}
@@ -736,7 +736,7 @@ func (ei *ExceptIter) Next(ctx *sql.Context) (sql.Row, error) {
736736
return nil, err
737737
}
738738

739-
hash, herr := sql.HashOf(ctx, res)
739+
hash, herr := hash.HashOf(ctx, nil, res)
740740
if herr != nil {
741741
return nil, herr
742742
}

sql/plan/hash_lookup.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ import (
1818
"fmt"
1919
"sync"
2020

21-
"github.com/dolthub/go-mysql-server/sql/types"
22-
2321
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/hash"
23+
"github.com/dolthub/go-mysql-server/sql/types"
2424
)
2525

2626
// NewHashLookup returns a node that performs an indexed hash lookup
@@ -127,7 +127,7 @@ func (n *HashLookup) GetHashKey(ctx *sql.Context, e sql.Expression, row sql.Row)
127127
return nil, err
128128
}
129129
if s, ok := key.([]interface{}); ok {
130-
return sql.HashOf(ctx, s)
130+
return hash.HashOf(ctx, n.Schema(), s)
131131
}
132132
// byte slices are not hashable
133133
if k, ok := key.([]byte); ok {

0 commit comments

Comments
 (0)