Skip to content

Commit ed1a26d

Browse files
committed
Add support for sql.TxOptions
1 parent 91f10e4 commit ed1a26d

File tree

3 files changed

+106
-6
lines changed

3 files changed

+106
-6
lines changed

conn.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -506,13 +506,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) {
506506
}
507507

508508
func (cn *conn) Begin() (_ driver.Tx, err error) {
509+
return cn.begin("")
510+
}
511+
512+
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
509513
if cn.bad {
510514
return nil, driver.ErrBadConn
511515
}
512516
defer cn.errRecover(&err)
513517

514518
cn.checkIsInTransaction(false)
515-
_, commandTag, err := cn.simpleExec("BEGIN")
519+
_, commandTag, err := cn.simpleExec("BEGIN" + mode)
516520
if err != nil {
517521
return nil, err
518522
}

conn_go18.go

+23-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ package pq
44

55
import (
66
"context"
7+
"database/sql"
78
"database/sql/driver"
8-
"errors"
9+
"fmt"
910
"io"
1011
"io/ioutil"
1112
)
@@ -44,13 +45,30 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
4445

4546
// Implement the "ConnBeginTx" interface
4647
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
47-
if opts.Isolation != 0 {
48-
return nil, errors.New("isolation levels not supported")
48+
var mode string
49+
50+
switch sql.IsolationLevel(opts.Isolation) {
51+
case sql.LevelDefault:
52+
// Don't touch mode: use the server's default
53+
case sql.LevelReadUncommitted:
54+
mode = " ISOLATION LEVEL READ UNCOMMITTED"
55+
case sql.LevelReadCommitted:
56+
mode = " ISOLATION LEVEL READ COMMITTED"
57+
case sql.LevelRepeatableRead:
58+
mode = " ISOLATION LEVEL REPEATABLE READ"
59+
case sql.LevelSerializable:
60+
mode = " ISOLATION LEVEL SERIALIZABLE"
61+
default:
62+
return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
4963
}
64+
5065
if opts.ReadOnly {
51-
return nil, errors.New("read-only transactions not supported")
66+
mode += " READ ONLY"
67+
} else {
68+
mode += " READ WRITE"
5269
}
53-
tx, err := cn.Begin()
70+
71+
tx, err := cn.begin(mode)
5472
if err != nil {
5573
return nil, err
5674
}

go18_test.go

+78
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"database/sql"
88
"runtime"
9+
"strings"
910
"testing"
1011
"time"
1112
)
@@ -239,3 +240,80 @@ func TestContextCancelBegin(t *testing.T) {
239240
}
240241
}
241242
}
243+
244+
func TestTxOptions(t *testing.T) {
245+
db := openTestConn(t)
246+
defer db.Close()
247+
ctx := context.Background()
248+
249+
tests := []struct {
250+
level sql.IsolationLevel
251+
isolation string
252+
}{
253+
{
254+
level: sql.LevelDefault,
255+
isolation: "",
256+
},
257+
{
258+
level: sql.LevelReadUncommitted,
259+
isolation: "read uncommitted",
260+
},
261+
{
262+
level: sql.LevelReadCommitted,
263+
isolation: "read committed",
264+
},
265+
{
266+
level: sql.LevelRepeatableRead,
267+
isolation: "repeatable read",
268+
},
269+
{
270+
level: sql.LevelSerializable,
271+
isolation: "serializable",
272+
},
273+
}
274+
275+
for _, test := range tests {
276+
for _, ro := range []bool{true, false} {
277+
tx, err := db.BeginTx(ctx, &sql.TxOptions{
278+
Isolation: test.level,
279+
ReadOnly: ro,
280+
})
281+
if err != nil {
282+
t.Fatal(err)
283+
}
284+
285+
var isolation string
286+
err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation)
287+
if err != nil {
288+
t.Fatal(err)
289+
}
290+
291+
if test.isolation != "" && isolation != test.isolation {
292+
t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation)
293+
}
294+
295+
var isRO string
296+
err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO)
297+
if err != nil {
298+
t.Fatal(err)
299+
}
300+
301+
if ro != (isRO == "on") {
302+
t.Errorf("read/[write,only] not set: %t != %s for level %s",
303+
ro, isRO, test.isolation)
304+
}
305+
306+
tx.Rollback()
307+
}
308+
}
309+
310+
_, err := db.BeginTx(ctx, &sql.TxOptions{
311+
Isolation: sql.LevelLinearizable,
312+
})
313+
if err == nil {
314+
t.Fatal("expected LevelLinearizable to fail")
315+
}
316+
if !strings.Contains(err.Error(), "isolation level not supported") {
317+
t.Errorf("Expected error to mention isolation level, got %q", err)
318+
}
319+
}

0 commit comments

Comments
 (0)