@@ -8,33 +8,36 @@ import (
8
8
"sqlflow.org/gohive/hiveserver2"
9
9
)
10
10
11
- // Options for opened Hive sessions.
12
- type Options struct {
11
+ // hiveOptions for opened Hive sessions.
12
+ type hiveOptions struct {
13
13
PollIntervalSeconds int64
14
14
BatchSize int64
15
15
}
16
16
17
- type Connection struct {
17
+ type hiveConnection struct {
18
18
thrift * hiveserver2.TCLIServiceClient
19
19
session * hiveserver2.TSessionHandle
20
- options Options
20
+ options hiveOptions
21
21
}
22
22
23
- func (c * Connection ) Begin () (driver.Tx , error ) {
23
+ func (c * hiveConnection ) Begin () (driver.Tx , error ) {
24
24
return nil , nil
25
25
}
26
26
27
- func (c * Connection ) Prepare (query string ) (driver.Stmt , error ) {
28
- return nil , nil
27
+ func (c * hiveConnection ) Prepare (qry string ) (driver.Stmt , error ) {
28
+ if ! c .isOpen () {
29
+ return nil , fmt .Errorf ("driver: bad connection" )
30
+ }
31
+ return & hiveStmt {hc : c , query : qry }, nil
29
32
}
30
33
31
- func (c * Connection ) isOpen () bool {
34
+ func (c * hiveConnection ) isOpen () bool {
32
35
return c .session != nil
33
36
}
34
37
35
38
// As hiveserver2 thrift api does not provide Ping method,
36
39
// we use GetInfo instead to check the health of hiveserver2.
37
- func (c * Connection ) Ping (ctx context.Context ) (err error ) {
40
+ func (c * hiveConnection ) Ping (ctx context.Context ) (err error ) {
38
41
getInfoReq := hiveserver2 .NewTGetInfoReq ()
39
42
getInfoReq .SessionHandle = c .session
40
43
getInfoReq .InfoType = hiveserver2 .TGetInfoType_CLI_SERVER_NAME
@@ -52,7 +55,7 @@ func (c *Connection) Ping(ctx context.Context) (err error) {
52
55
return nil
53
56
}
54
57
55
- func (c * Connection ) Close () error {
58
+ func (c * hiveConnection ) Close () error {
56
59
if c .isOpen () {
57
60
closeReq := hiveserver2 .NewTCloseSessionReq ()
58
61
closeReq .SessionHandle = c .session
@@ -74,7 +77,7 @@ func removeLastSemicolon(s string) string {
74
77
return s
75
78
}
76
79
77
- func (c * Connection ) execute (ctx context.Context , query string , args []driver.NamedValue ) (* hiveserver2.TExecuteStatementResp , error ) {
80
+ func (c * hiveConnection ) execute (ctx context.Context , query string , args []driver.NamedValue ) (* hiveserver2.TExecuteStatementResp , error ) {
78
81
executeReq := hiveserver2 .NewTExecuteStatementReq ()
79
82
executeReq .SessionHandle = c .session
80
83
executeReq .Statement = removeLastSemicolon (query )
@@ -90,15 +93,15 @@ func (c *Connection) execute(ctx context.Context, query string, args []driver.Na
90
93
return resp , nil
91
94
}
92
95
93
- func (c * Connection ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
96
+ func (c * hiveConnection ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
94
97
resp , err := c .execute (ctx , query , args )
95
98
if err != nil {
96
99
return nil , err
97
100
}
98
101
return newRows (c .thrift , resp .OperationHandle , c .options ), nil
99
102
}
100
103
101
- func (c * Connection ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
104
+ func (c * hiveConnection ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
102
105
resp , err := c .execute (ctx , query , args )
103
106
if err != nil {
104
107
return nil , err
0 commit comments