Skip to content

Commit 0350a00

Browse files
committed
[sqlserver] Always access version table with explicit schema
1 parent 8147693 commit 0350a00

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

database/sqlserver/sqlserver.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
263263
return &database.Error{OrigErr: err, Err: "transaction start failed"}
264264
}
265265

266-
query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
266+
query := `TRUNCATE TABLE ` + ss.getMigrationTable()
267267
if _, err := tx.Exec(query); err != nil {
268268
if errRollback := tx.Rollback(); errRollback != nil {
269269
err = multierror.Append(err, errRollback)
@@ -279,7 +279,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
279279
if dirty {
280280
dirtyBit = 1
281281
}
282-
query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
282+
query = `INSERT INTO ` + ss.getMigrationTable() + ` (version, dirty) VALUES (@p1, @p2)`
283283
if _, err := tx.Exec(query, version, dirtyBit); err != nil {
284284
if errRollback := tx.Rollback(); errRollback != nil {
285285
err = multierror.Append(err, errRollback)
@@ -297,7 +297,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error {
297297

298298
// Version of the current database state
299299
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
300-
query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
300+
query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable()
301301
err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
302302
switch {
303303
case err == sql.ErrNoRows:
@@ -365,10 +365,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) {
365365
query := `IF NOT EXISTS
366366
(SELECT *
367367
FROM sysobjects
368-
WHERE id = object_id(N'[` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `]')
368+
WHERE id = object_id(N'` + ss.getMigrationTable() + `')
369369
AND OBJECTPROPERTY(id, N'IsUserTable') = 1
370370
)
371-
CREATE TABLE [` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `] ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
371+
CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
372372

373373
if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
374374
return &database.Error{OrigErr: err, Query: []byte(query)}
@@ -377,6 +377,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) {
377377
return nil
378378
}
379379

380+
func (ss *SQLServer) getMigrationTable() string {
381+
return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable)
382+
}
383+
380384
func getMSITokenProvider(resource string) (func() (string, error), error) {
381385
msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
382386
if err != nil {

0 commit comments

Comments
 (0)