diff --git a/.gitignore b/.gitignore index daf913b..aaa0a73 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ _testmain.go *.exe *.test *.prof +.idea/ diff --git a/cmd/db2struct/main.go b/cmd/db2struct/main.go index a6b85c6..b101442 100644 --- a/cmd/db2struct/main.go +++ b/cmd/db2struct/main.go @@ -84,7 +84,7 @@ func main() { return } - columnDataTypes, err := db2struct.GetColumnsFromMysqlTable(*mariadbUser, *mariadbPassword, mariadbHost, *mariadbPort, *mariadbDatabase, *mariadbTable) + columnDataTypes, columnsSorted, err := db2struct.GetColumnsFromMysqlTable(*mariadbUser, *mariadbPassword, mariadbHost, *mariadbPort, *mariadbDatabase, *mariadbTable) if err != nil { fmt.Println("Error in selecting column data information from mysql information schema") @@ -100,7 +100,7 @@ func main() { *packageName = "newpackage" } // Generate struct string based on columnDataTypes - struc, err := db2struct.Generate(*columnDataTypes, *mariadbTable, *structName, *packageName, *jsonAnnotation, *gormAnnotation, *gureguTypes) + struc, err := db2struct.Generate(*columnDataTypes, columnsSorted, *mariadbTable, *structName, *packageName, *jsonAnnotation, *gormAnnotation, *gureguTypes) if err != nil { fmt.Println("Error in creating struct from json: " + err.Error()) diff --git a/utils.go b/utils.go index 88f341e..fbcd9a2 100644 --- a/utils.go +++ b/utils.go @@ -81,10 +81,10 @@ var Debug = false // Generate Given a Column map with datatypes and a name structName, // attempts to generate a struct definition -func Generate(columnTypes map[string]map[string]string, tableName string, structName string, pkgName string, jsonAnnotation bool, gormAnnotation bool, gureguTypes bool) ([]byte, error) { +func Generate(columnTypes map[string]map[string]string, columnsSorted []string, tableName string, structName string, pkgName string, jsonAnnotation bool, gormAnnotation bool, gureguTypes bool) ([]byte, error) { var dbTypes string - dbTypes = generateMysqlTypes(columnTypes, 0, jsonAnnotation, gormAnnotation, gureguTypes) - src := fmt.Sprintf("package %s\ntype %s %s}", + dbTypes = generateMysqlTypes(columnTypes, columnsSorted, 0, jsonAnnotation, gormAnnotation, gureguTypes) + src := fmt.Sprintf("package %s\ntype %s %s\n}", pkgName, structName, dbTypes) diff --git a/utils_mysql.go b/utils_mysql.go index 261d70e..d254c0e 100644 --- a/utils_mysql.go +++ b/utils_mysql.go @@ -4,13 +4,12 @@ import ( "database/sql" "errors" "fmt" - "sort" "strconv" "strings" ) // GetColumnsFromMysqlTable Select column details from information schema and return map of map -func GetColumnsFromMysqlTable(mariadbUser string, mariadbPassword string, mariadbHost string, mariadbPort int, mariadbDatabase string, mariadbTable string) (*map[string]map[string]string, error) { +func GetColumnsFromMysqlTable(mariadbUser string, mariadbPassword string, mariadbHost string, mariadbPort int, mariadbDatabase string, mariadbTable string) (*map[string]map[string]string, []string, error) { var err error var db *sql.DB @@ -24,13 +23,15 @@ func GetColumnsFromMysqlTable(mariadbUser string, mariadbPassword string, mariad // Check for error in db, note this does not check connectivity but does check uri if err != nil { fmt.Println("Error opening mysql db: " + err.Error()) - return nil, err + return nil, nil, err } + columnNamesSorted := []string{} + // Store colum as map of maps columnDataTypes := make(map[string]map[string]string) // Select columnd data from INFORMATION_SCHEMA - columnDataTypeQuery := "SELECT COLUMN_NAME, COLUMN_KEY, DATA_TYPE, IS_NULLABLE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? AND table_name = ?" + columnDataTypeQuery := "SELECT COLUMN_NAME, COLUMN_KEY, DATA_TYPE, IS_NULLABLE, COLUMN_COMMENT FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? AND table_name = ?" if Debug { fmt.Println("running: " + columnDataTypeQuery) @@ -40,12 +41,12 @@ func GetColumnsFromMysqlTable(mariadbUser string, mariadbPassword string, mariad if err != nil { fmt.Println("Error selecting from db: " + err.Error()) - return nil, err + return nil, nil, err } if rows != nil { defer rows.Close() } else { - return nil, errors.New("No results returned for table") + return nil, nil, errors.New("No results returned for table") } for rows.Next() { @@ -53,25 +54,21 @@ func GetColumnsFromMysqlTable(mariadbUser string, mariadbPassword string, mariad var columnKey string var dataType string var nullable string - rows.Scan(&column, &columnKey, &dataType, &nullable) + var comment string + rows.Scan(&column, &columnKey, &dataType, &nullable, &comment) - columnDataTypes[column] = map[string]string{"value": dataType, "nullable": nullable, "primary": columnKey} + columnDataTypes[column] = map[string]string{"value": dataType, "nullable": nullable, "primary": columnKey, "comment": comment} + columnNamesSorted = append(columnNamesSorted, column) } - return &columnDataTypes, err + return &columnDataTypes, columnNamesSorted, err } // Generate go struct entries for a map[string]interface{} structure -func generateMysqlTypes(obj map[string]map[string]string, depth int, jsonAnnotation bool, gormAnnotation bool, gureguTypes bool) string { +func generateMysqlTypes(obj map[string]map[string]string, columnsSorted []string, depth int, jsonAnnotation bool, gormAnnotation bool, gureguTypes bool) string { structure := "struct {" - keys := make([]string, 0, len(obj)) - for key := range obj { - keys = append(keys, key) - } - sort.Strings(keys) - - for _, key := range keys { + for _, key := range columnsSorted { mysqlType := obj[key] nullable := false if mysqlType["nullable"] == "YES" { @@ -97,16 +94,14 @@ func generateMysqlTypes(obj map[string]map[string]string, depth int, jsonAnnotat if jsonAnnotation == true { annotations = append(annotations, fmt.Sprintf("json:\"%s\"", key)) } - if len(annotations) > 0 { - structure += fmt.Sprintf("\n%s %s `%s`", - fieldName, - valueType, - strings.Join(annotations, " ")) + if len(annotations) > 0 { + // add colulmn comment + comment:=mysqlType["comment"] + structure += fmt.Sprintf("\n%s %s `%s` //%s", fieldName, valueType, strings.Join(annotations, " "), comment) + //structure += fmt.Sprintf("\n%s %s `%s`", fieldName, valueType, strings.Join(annotations, " ")) } else { - structure += fmt.Sprintf("\n%s %s", - fieldName, - valueType) + structure += fmt.Sprintf("\n%s %s",fieldName,valueType) } } return structure