Skip to content

Commit abe07b6

Browse files
authored
support auth mechanism (#51)
* support auth mechnism * polish code * follow comment
1 parent 1e12c84 commit abe07b6

File tree

4 files changed

+90
-22
lines changed

4 files changed

+90
-22
lines changed

driver.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7-
"errors"
8-
"github.com/apache/thrift/lib/go/thrift"
7+
"fmt"
98

9+
"github.com/apache/thrift/lib/go/thrift"
10+
bgohive "github.com/beltran/gohive"
1011
hiveserver2 "github.com/sql-machine-learning/gohive/hiveserver2/gen-go/tcliservice"
1112
)
1213

@@ -17,23 +18,36 @@ func (d drv) Open(dsn string) (driver.Conn, error) {
1718
if err != nil {
1819
return nil, err
1920
}
20-
transport, err := thrift.NewTSocket(cfg.Addr)
21+
socket, err := thrift.NewTSocket(cfg.Addr)
2122
if err != nil {
2223
return nil, err
2324
}
24-
25-
if err := transport.Open(); err != nil {
26-
return nil, err
25+
var transport thrift.TTransport
26+
if cfg.Auth == "NOSASL" {
27+
transport = thrift.NewTBufferedTransport(socket, 4096)
28+
if transport == nil {
29+
return nil, fmt.Errorf("BufferedTransport is nil")
30+
}
31+
} else if cfg.Auth == "PLAIN" || cfg.Auth == "GSSAPI" || cfg.Auth == "LDAP" {
32+
saslCfg := map[string]string{
33+
"username": cfg.User,
34+
"password": cfg.Passwd,
35+
}
36+
transport, err = bgohive.NewTSaslTransport(socket, cfg.Addr, cfg.Auth, saslCfg)
37+
if err != nil {
38+
return nil, fmt.Errorf("create SasalTranposrt failed: %v", err)
39+
}
40+
} else {
41+
return nil, fmt.Errorf("unrecognized auth mechanism: %s", cfg.Auth)
2742
}
28-
29-
if transport == nil {
30-
return nil, errors.New("nil thrift transport")
43+
if err = transport.Open(); err != nil {
44+
return nil, err
3145
}
3246

3347
protocol := thrift.NewTBinaryProtocolFactoryDefault()
3448
client := hiveserver2.NewTCLIServiceClientFactory(transport, protocol)
3549
s := hiveserver2.NewTOpenSessionReq()
36-
s.ClientProtocol = 6
50+
s.ClientProtocol = hiveserver2.TProtocolVersion_HIVE_CLI_SERVICE_PROTOCOL_V6
3751
if cfg.User != "" {
3852
s.Username = &cfg.User
3953
if cfg.Passwd != "" {

driver_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ func TestOpenConnection(t *testing.T) {
1414
defer db.Close()
1515
}
1616

17+
func TestOpenConnectionAgainstAuth(t *testing.T) {
18+
db, _ := sql.Open("hive", "127.0.0.1:10000/churn?auth=PLAIN")
19+
rows, err := db.Query("SELECT customerID, gender FROM train")
20+
assert.EqualError(t, err, "Bad SASL negotiation status: 4 ()")
21+
defer db.Close()
22+
if err == nil {
23+
defer rows.Close()
24+
}
25+
}
26+
1727
func TestQuery(t *testing.T) {
1828
db, _ := sql.Open("hive", "127.0.0.1:10000/churn")
1929
rows, err := db.Query("SELECT customerID, gender FROM train")

dsn.go

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,22 @@ type Config struct {
1111
Passwd string
1212
Addr string
1313
DBName string
14+
Auth string
1415
}
1516

1617
var (
1718
// Regexp syntax: https://github.com/google/re2/wiki/Syntax
18-
reDSN = regexp.MustCompile(`(.+@)?([^@]+)`)
19+
reDSN = regexp.MustCompile(`(.+@)?([^@|^?]+)\\?(.*)`)
1920
reUserPasswd = regexp.MustCompile(`([^:@]+)(:[^:@]+)?@`)
21+
reArguments = regexp.MustCompile(`(\w+)=(\w+)`)
2022
)
2123

2224
// ParseDSN requires DSN names in the format [user[:password]@]addr/dbname.
2325
func ParseDSN(dsn string) (*Config, error) {
2426
// Please read https://play.golang.org/p/_CSLvl1AxOX before code review.
2527
sub := reDSN.FindStringSubmatch(dsn)
26-
if len(sub) != 3 {
27-
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname]", dsn)
28+
if len(sub) != 4 {
29+
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
2830
}
2931
addr := ""
3032
dbname := ""
@@ -35,21 +37,45 @@ func ParseDSN(dsn string) (*Config, error) {
3537
} else {
3638
addr = sub[2]
3739
}
40+
user := ""
41+
passwd := ""
42+
auth := "NOSASL"
3843
up := reUserPasswd.FindStringSubmatch(sub[1])
3944
if len(up) == 3 {
45+
user = up[1]
4046
if len(up[2]) > 0 {
41-
return &Config{User: up[1], Passwd: up[2][1:], Addr: addr, DBName: dbname}, nil
47+
passwd = up[2][1:]
4248
}
43-
return &Config{User: up[1], Addr: addr, DBName: dbname}, nil
4449
}
45-
return &Config{Addr: addr, DBName: dbname}, nil
50+
51+
args := reArguments.FindAllStringSubmatch(sub[3], -1)
52+
if len(args) > 1 {
53+
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
54+
}
55+
56+
if len(args) == 1 {
57+
if args[0][1] != "auth" {
58+
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
59+
}
60+
auth = args[0][2]
61+
}
62+
return &Config{
63+
User: user,
64+
Passwd: passwd,
65+
Addr: addr,
66+
DBName: dbname,
67+
Auth: auth,
68+
}, nil
4669
}
4770

48-
// FormatDSN outputs a string in the format "user:password@address"
71+
// FormatDSN outputs a string in the format "user:password@address?auth=xxx"
4972
func (cfg *Config) FormatDSN() string {
73+
dsn := fmt.Sprintf("%s:%s@%s", cfg.User, cfg.Passwd, cfg.Addr)
5074
if len(cfg.DBName) > 0 {
51-
return fmt.Sprintf("%s:%s@%s/%s", cfg.User, cfg.Passwd, cfg.Addr, cfg.DBName)
52-
} else {
53-
return fmt.Sprintf("%s:%s@%s", cfg.User, cfg.Passwd, cfg.Addr)
75+
dsn = fmt.Sprintf("%s/%s", dsn, cfg.DBName)
76+
}
77+
if len(cfg.Auth) > 0 {
78+
dsn = fmt.Sprintf("%s?auth=%s", dsn, cfg.Auth)
5479
}
80+
return dsn
5581
}

dsn_test.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ import (
66
"github.com/stretchr/testify/assert"
77
)
88

9+
func TestParseDSNWithAuth(t *testing.T) {
10+
cfg, e := ParseDSN("root:[email protected]/mnist?auth=PLAIN")
11+
assert.Nil(t, e)
12+
assert.Equal(t, cfg.User, "root")
13+
assert.Equal(t, cfg.Passwd, "root")
14+
assert.Equal(t, cfg.Addr, "127.0.0.1")
15+
assert.Equal(t, cfg.DBName, "mnist")
16+
assert.Equal(t, cfg.Auth, "PLAIN")
17+
18+
cfg, e = ParseDSN("[email protected]/mnist")
19+
assert.Nil(t, e)
20+
assert.Equal(t, cfg.User, "root")
21+
assert.Equal(t, cfg.Passwd, "")
22+
assert.Equal(t, cfg.Addr, "127.0.0.1")
23+
assert.Equal(t, cfg.DBName, "mnist")
24+
assert.Equal(t, cfg.Auth, "NOSASL")
25+
}
26+
927
func TestParseDSNWithDBName(t *testing.T) {
1028
cfg, e := ParseDSN("root:[email protected]/mnist")
1129
assert.Nil(t, e)
@@ -50,7 +68,7 @@ func TestParseDSNWithoutDBName(t *testing.T) {
5068
}
5169

5270
func TestFormatDSNWithDBName(t *testing.T) {
53-
ds := "user:[email protected]/mnist"
71+
ds := "user:[email protected]/mnist?auth=NOSASL"
5472
cfg, e := ParseDSN(ds)
5573
assert.Nil(t, e)
5674

@@ -59,7 +77,7 @@ func TestFormatDSNWithDBName(t *testing.T) {
5977
}
6078

6179
func TestFormatDSNWithoutDBName(t *testing.T) {
62-
ds := "user:[email protected]"
80+
ds := "user:[email protected]?auth=NOSASL"
6381
cfg, e := ParseDSN(ds)
6482
assert.Nil(t, e)
6583

0 commit comments

Comments
 (0)