Skip to content

Commit caaad38

Browse files
committed
refactor: part 2 of distinguish between Unique and UniqueIndex
1 parent 46816ad commit caaad38

File tree

3 files changed

+234
-20
lines changed

3 files changed

+234
-20
lines changed

migrator/migrator.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
9393
expr.SQL += " NOT NULL"
9494
}
9595

96-
if field.Unique {
97-
expr.SQL += " UNIQUE"
98-
}
99-
10096
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
10197
if field.DefaultValueInterface != nil {
10298
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
@@ -512,14 +508,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
512508
}
513509
}
514510

515-
// check unique
516-
if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") {
517-
// not primary key
518-
if !field.PrimaryKey {
519-
alterColumn = true
520-
}
521-
}
522-
523511
// check default value
524512
if !field.PrimaryKey {
525513
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
@@ -548,8 +536,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
548536
}
549537
}
550538

551-
if alterColumn && !field.IgnoreMigration {
552-
return m.DB.Migrator().AlterColumn(value, field.DBName)
539+
if alterColumn {
540+
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
541+
return err
542+
}
543+
}
544+
545+
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
546+
return err
553547
}
554548

555549
return nil

tests/go.mod

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,34 @@ module gorm.io/gorm/tests
33
go 1.18
44

55
require (
6-
github.com/google/uuid v1.5.0
6+
github.com/google/uuid v1.6.0
77
github.com/jinzhu/now v1.1.5
88
github.com/lib/pq v1.10.9
9+
github.com/stretchr/testify v1.8.4
910
gorm.io/driver/mysql v1.5.2
1011
gorm.io/driver/postgres v1.5.4
1112
gorm.io/driver/sqlite v1.5.4
1213
gorm.io/driver/sqlserver v1.5.2
13-
gorm.io/gorm v1.25.5
14+
gorm.io/gorm v1.25.6
1415
)
1516

1617
require (
18+
github.com/davecgh/go-spew v1.1.1 // indirect
1719
github.com/go-sql-driver/mysql v1.7.1 // indirect
1820
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
1921
github.com/golang-sql/sqlexp v0.1.0 // indirect
2022
github.com/jackc/pgpassfile v1.0.0 // indirect
2123
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
22-
github.com/jackc/pgx/v5 v5.5.1 // indirect
24+
github.com/jackc/pgx/v5 v5.5.3 // indirect
2325
github.com/jinzhu/inflection v1.0.0 // indirect
24-
github.com/mattn/go-sqlite3 v1.14.19 // indirect
26+
github.com/kr/text v0.2.0 // indirect
27+
github.com/mattn/go-sqlite3 v1.14.22 // indirect
2528
github.com/microsoft/go-mssqldb v1.6.0 // indirect
29+
github.com/pmezard/go-difflib v1.0.0 // indirect
30+
github.com/rogpeppe/go-internal v1.12.0 // indirect
2631
golang.org/x/crypto v0.18.0 // indirect
2732
golang.org/x/text v0.14.0 // indirect
33+
gopkg.in/yaml.v3 v3.0.1 // indirect
2834
)
2935

3036
replace gorm.io/gorm => ../

tests/migrate_test.go

Lines changed: 216 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tests_test
22

33
import (
44
"context"
5+
"database/sql"
56
"fmt"
67
"math/rand"
78
"os"
@@ -10,10 +11,15 @@ import (
1011
"testing"
1112
"time"
1213

14+
"github.com/stretchr/testify/assert"
1315
"gorm.io/driver/postgres"
16+
1417
"gorm.io/gorm"
18+
"gorm.io/gorm/clause"
1519
"gorm.io/gorm/logger"
20+
"gorm.io/gorm/migrator"
1621
"gorm.io/gorm/schema"
22+
"gorm.io/gorm/utils"
1723
. "gorm.io/gorm/utils/tests"
1824
)
1925

@@ -984,7 +990,8 @@ func TestCurrentTimestamp(t *testing.T) {
984990
if err != nil {
985991
t.Fatalf("AutoMigrate err:%v", err)
986992
}
987-
AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
993+
AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at"))
994+
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
988995
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
989996
}
990997

@@ -1046,7 +1053,8 @@ func TestUniqueColumn(t *testing.T) {
10461053
}
10471054

10481055
// not trigger alert column
1049-
AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
1056+
AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name"))
1057+
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
10501058
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1"))
10511059
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2"))
10521060

@@ -1712,3 +1720,209 @@ func TestTableType(t *testing.T) {
17121720
t.Fatalf("expected comment %s got %s", tblComment, comment)
17131721
}
17141722
}
1723+
1724+
func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
1725+
const table = "unique_struct"
1726+
1727+
checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) {
1728+
stmt := &gorm.Statement{DB: DB}
1729+
err := stmt.Parse(model)
1730+
if err != nil {
1731+
t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err)
1732+
}
1733+
_ = stmt.Schema.ParseIndexes()
1734+
field := stmt.Schema.LookUpField(fieldName)
1735+
if field == nil {
1736+
t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName)
1737+
}
1738+
if field.Unique != unique {
1739+
t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique)
1740+
}
1741+
if field.UniqueIndex != uniqueIndex {
1742+
t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex)
1743+
}
1744+
}
1745+
1746+
type ( // not unique
1747+
UniqueStruct1 struct {
1748+
Name string `gorm:"size:10"`
1749+
}
1750+
UniqueStruct2 struct {
1751+
Name string `gorm:"size:20"`
1752+
}
1753+
)
1754+
checkField(&UniqueStruct1{}, "name", false, "")
1755+
checkField(&UniqueStruct2{}, "name", false, "")
1756+
1757+
type ( // unique
1758+
UniqueStruct3 struct {
1759+
Name string `gorm:"size:30;unique"`
1760+
}
1761+
UniqueStruct4 struct {
1762+
Name string `gorm:"size:40;unique"`
1763+
}
1764+
)
1765+
checkField(&UniqueStruct3{}, "name", true, "")
1766+
checkField(&UniqueStruct4{}, "name", true, "")
1767+
1768+
type ( // uniqueIndex
1769+
UniqueStruct5 struct {
1770+
Name string `gorm:"size:50;uniqueIndex"`
1771+
}
1772+
UniqueStruct6 struct {
1773+
Name string `gorm:"size:60;uniqueIndex"`
1774+
}
1775+
UniqueStruct7 struct {
1776+
Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
1777+
NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
1778+
}
1779+
)
1780+
checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name")
1781+
checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name")
1782+
1783+
checkField(&UniqueStruct7{}, "name", false, "")
1784+
checkField(&UniqueStruct7{}, "nick_name", false, "")
1785+
checkField(&UniqueStruct7{}, "nick_name", false, "")
1786+
1787+
type UniqueStruct8 struct { // unique and uniqueIndex
1788+
Name string `gorm:"size:60;unique;index:my_us8_index,unique;"`
1789+
}
1790+
checkField(&UniqueStruct8{}, "name", true, "my_us8_index")
1791+
1792+
type TestCase struct {
1793+
name string
1794+
from, to interface{}
1795+
checkFunc func(t *testing.T)
1796+
}
1797+
1798+
checkColumnType := func(t *testing.T, fieldName string, unique bool) {
1799+
columnTypes, err := DB.Migrator().ColumnTypes(table)
1800+
if err != nil {
1801+
t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err)
1802+
}
1803+
var found gorm.ColumnType
1804+
for _, columnType := range columnTypes {
1805+
if columnType.Name() == fieldName {
1806+
found = columnType
1807+
}
1808+
}
1809+
if found == nil {
1810+
t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName)
1811+
}
1812+
if actualUnique, ok := found.Unique(); !ok || actualUnique != unique {
1813+
t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique)
1814+
}
1815+
}
1816+
1817+
checkIndex := func(t *testing.T, expected []gorm.Index) {
1818+
indexes, err := DB.Migrator().GetIndexes(table)
1819+
if err != nil {
1820+
t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err)
1821+
}
1822+
assert.ElementsMatch(t, expected, indexes)
1823+
}
1824+
1825+
uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
1826+
myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
1827+
mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
1828+
1829+
var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T)
1830+
// UniqueAffectedByUniqueIndex is true
1831+
if DB.Dialector.Name() == "mysql" {
1832+
uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
1833+
checkNotUnique = func(t *testing.T) {
1834+
checkColumnType(t, "name", false)
1835+
checkIndex(t, nil)
1836+
}
1837+
checkUnique = func(t *testing.T) {
1838+
checkColumnType(t, "name", true)
1839+
checkIndex(t, []gorm.Index{uniqueConstraintIndex})
1840+
}
1841+
checkUniqueIndex = func(t *testing.T) {
1842+
checkColumnType(t, "name", true)
1843+
checkIndex(t, []gorm.Index{uniqueIndex})
1844+
}
1845+
checkMyIndex = func(t *testing.T) {
1846+
checkColumnType(t, "name", true)
1847+
checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex})
1848+
}
1849+
checkMulIndex = func(t *testing.T) {
1850+
checkColumnType(t, "name", false)
1851+
checkColumnType(t, "nick_name", false)
1852+
checkIndex(t, []gorm.Index{mulIndex})
1853+
}
1854+
} else {
1855+
checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) }
1856+
checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) }
1857+
checkUniqueIndex = func(t *testing.T) {
1858+
checkColumnType(t, "name", false)
1859+
checkIndex(t, []gorm.Index{uniqueIndex})
1860+
}
1861+
checkMyIndex = func(t *testing.T) {
1862+
checkColumnType(t, "name", true)
1863+
if !DB.Migrator().HasIndex(table, myIndex.Name()) {
1864+
t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name())
1865+
}
1866+
}
1867+
checkMulIndex = func(t *testing.T) {
1868+
checkColumnType(t, "name", false)
1869+
checkColumnType(t, "nick_name", false)
1870+
if !DB.Migrator().HasIndex(table, mulIndex.Name()) {
1871+
t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name())
1872+
}
1873+
}
1874+
}
1875+
1876+
tests := []TestCase{
1877+
{name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
1878+
{name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
1879+
{name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
1880+
{name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
1881+
{name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique},
1882+
{name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique},
1883+
{name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
1884+
{name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
1885+
{name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
1886+
{name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
1887+
{name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex},
1888+
{name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
1889+
{name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex},
1890+
}
1891+
for _, test := range tests {
1892+
t.Run(test.name, func(t *testing.T) {
1893+
if err := DB.Migrator().DropTable(table); err != nil {
1894+
t.Fatalf("failed to drop table, got error: %v", err)
1895+
}
1896+
if err := DB.Table(table).AutoMigrate(test.from); err != nil {
1897+
t.Fatalf("failed to migrate table, got error: %v", err)
1898+
}
1899+
if err := DB.Table(table).AutoMigrate(test.to); err != nil {
1900+
t.Fatalf("failed to migrate table, got error: %v", err)
1901+
}
1902+
test.checkFunc(t)
1903+
})
1904+
}
1905+
1906+
if DB.Dialector.Name() == "mysql" {
1907+
compatibilityTests := []TestCase{
1908+
{name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique},
1909+
{name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique},
1910+
{name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex},
1911+
{name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex},
1912+
}
1913+
for _, test := range compatibilityTests {
1914+
t.Run(test.name, func(t *testing.T) {
1915+
if err := DB.Migrator().DropTable(table); err != nil {
1916+
t.Fatalf("failed to drop table, got error: %v", err)
1917+
}
1918+
if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil {
1919+
t.Fatalf("failed to create table, got error: %v", err)
1920+
}
1921+
if err := DB.Table(table).AutoMigrate(test.to); err != nil {
1922+
t.Fatalf("failed to migrate table, got error: %v", err)
1923+
}
1924+
test.checkFunc(t)
1925+
})
1926+
}
1927+
}
1928+
}

0 commit comments

Comments
 (0)