From 75ac923d83305836da5e8822709912972fa60233 Mon Sep 17 00:00:00 2001 From: Gary Linscott Date: Fri, 1 Jun 2018 20:57:31 -0700 Subject: [PATCH] Merge pull request #664 from mooskagh/server-config Move hardcoded things out of server code into config file. --- go/src/server/cmd/bootstrap/main.go | 6 +- go/src/server/config/config.go | 43 ++++++++++++ go/src/server/db/db.go | 27 +++++--- go/src/server/main.go | 100 ++++++++++++++++------------ go/src/server/serverconfig.json | 24 +++++++ 5 files changed, 147 insertions(+), 53 deletions(-) create mode 100644 go/src/server/config/config.go create mode 100644 go/src/server/serverconfig.json diff --git a/go/src/server/cmd/bootstrap/main.go b/go/src/server/cmd/bootstrap/main.go index 88ebbddd3..5836612b0 100644 --- a/go/src/server/cmd/bootstrap/main.go +++ b/go/src/server/cmd/bootstrap/main.go @@ -5,8 +5,8 @@ import ( ) func main() { - db.Init(true) - db.SetupDB() - db.CreateTrainingRun() + db.Init() defer db.Close() + db.SetupDB() + db.CreateTrainingRun("Initial run just for test") } diff --git a/go/src/server/config/config.go b/go/src/server/config/config.go new file mode 100644 index 000000000..74a2a8f4b --- /dev/null +++ b/go/src/server/config/config.go @@ -0,0 +1,43 @@ +package config + +import ( + "encoding/json" + "io/ioutil" +) + +// Config is a Server config. +var Config struct { + Database struct { + Host string + User string + Dbname string + Password string + } + Clients struct { + MinClientVersion uint64 + MinEngineVersion string + } + URLs struct { + OnNewNetwork []string + NetworkLocation string + } + Matches struct { + Games int + Parameters []interface{} + Threshold float64 + } + WebServer struct { + Address string + } +} + +func init() { + content, err := ioutil.ReadFile("serverconfig.json") + if err != nil { + panic(err) + } + err = json.Unmarshal(content, &Config) + if err != nil { + panic(err) + } +} diff --git a/go/src/server/db/db.go b/go/src/server/db/db.go index d3f3d81e2..306f03a41 100644 --- a/go/src/server/db/db.go +++ b/go/src/server/db/db.go @@ -5,24 +5,30 @@ import ( "log" "github.com/jinzhu/gorm" + // Importing to support postgre database. _ "github.com/jinzhu/gorm/dialects/postgres" + "server/config" ) var db *gorm.DB var err error -func Init(prod bool) { - dbname := "gorm_test" - if prod { - dbname = "gorm" - } - conn := fmt.Sprintf("host=localhost user=gorm dbname=%s sslmode=disable password=gorm", dbname) +// Init initializes database. +func Init() { + conn := fmt.Sprintf( + "host=%s user=%s dbname=%s sslmode=disable password=%s", + config.Config.Database.Host, + config.Config.Database.User, + config.Config.Database.Dbname, + config.Config.Database.Password, + ) db, err = gorm.Open("postgres", conn) if err != nil { log.Fatal("Unable to connect to DB", err) } } +// SetupDB setups DB. func SetupDB() { db.AutoMigrate(&User{}) db.AutoMigrate(&TrainingRun{}) @@ -32,19 +38,22 @@ func SetupDB() { db.AutoMigrate(&TrainingGame{}) } +// CreateTrainingRun creates training run func CreateTrainingRun(description string) *TrainingRun { - training_run := TrainingRun{Description: description} - err := db.Create(&training_run).Error + trainingRun := TrainingRun{Description: description} + err := db.Create(&trainingRun).Error if err != nil { log.Fatal(err) } - return &training_run + return &trainingRun } +// GetDB returns current database object func GetDB() *gorm.DB { return db } +// Close closes database func Close() { db.Close() } diff --git a/go/src/server/main.go b/go/src/server/main.go index 937eabc4e..ad5d666e0 100644 --- a/go/src/server/main.go +++ b/go/src/server/main.go @@ -3,6 +3,7 @@ package main import ( "compress/gzip" "crypto/sha256" + "encoding/json" "errors" "fmt" "io" @@ -14,6 +15,7 @@ import ( "os" "os/exec" "path/filepath" + "server/config" "server/db" "strconv" "strings" @@ -49,9 +51,9 @@ func checkUser(c *gin.Context) (*db.User, uint64, error) { if err != nil { return nil, 0, errors.New("Invalid version") } - if version < 10 { + if version < config.Config.Clients.MinClientVersion { log.Printf("Rejecting old game from %s, version %d\n", user.Username, version) - return nil, 0, errors.New("\n\n\n\n\nYou must upgrade to a newer version!!\n\n\n\n\n") + return nil, 0, errors.New("you must upgrade to a newer version") } return user, version, nil @@ -65,11 +67,11 @@ func nextGame(c *gin.Context) { return } - training_run := db.TrainingRun{ + trainingRun := db.TrainingRun{ Active: true, } // TODO(gary): Only really supports one training run right now... - err = db.GetDB().Where(&training_run).First(&training_run).Error + err = db.GetDB().Where(&trainingRun).First(&trainingRun).Error if err != nil { log.Println(err) c.String(http.StatusBadRequest, "Invalid training run") @@ -77,10 +79,10 @@ func nextGame(c *gin.Context) { } network := db.Network{} - err = db.GetDB().Where("id = ?", training_run.BestNetworkID).First(&network).Error + err = db.GetDB().Where("id = ?", trainingRun.BestNetworkID).First(&network).Error if err != nil { log.Println(err) - c.String(500, "Internal error") + c.String(500, "Internal error 1") return } @@ -89,28 +91,28 @@ func nextGame(c *gin.Context) { err = db.GetDB().Preload("Candidate").Where("done=false").Limit(1).Find(&match).Error if err != nil { log.Println(err) - c.String(500, "Internal error") + c.String(500, "Internal error 2") return } if len(match) > 0 { // Return this match - match_game := db.MatchGame{ + matchGame := db.MatchGame{ UserID: user.ID, MatchID: match[0].ID, } - err = db.GetDB().Create(&match_game).Error + err = db.GetDB().Create(&matchGame).Error // Note, this could cause an imbalance of white/black games for a particular match, // but it's good enough for now. - flip := (match_game.ID & 1) == 1 - db.GetDB().Model(&match_game).Update("flip", flip) + flip := (matchGame.ID & 1) == 1 + db.GetDB().Model(&matchGame).Update("flip", flip) if err != nil { log.Println(err) - c.String(500, "Internal error") + c.String(500, "Internal error 3") return } result := gin.H{ "type": "match", - "matchGameId": match_game.ID, + "matchGameId": matchGame.ID, "sha": network.Sha, "candidateSha": match[0].Candidate.Sha, "params": match[0].Parameters, @@ -123,18 +125,18 @@ func nextGame(c *gin.Context) { result := gin.H{ "type": "train", - "trainingId": training_run.ID, - "networkId": training_run.BestNetworkID, + "trainingId": trainingRun.ID, + "networkId": trainingRun.BestNetworkID, "sha": network.Sha, - "params": training_run.TrainParameters, + "params": trainingRun.TrainParameters, } c.JSON(http.StatusOK, result) } // Computes SHA256 of gzip compressed file -func computeSha(http_file *multipart.FileHeader) (string, error) { +func computeSha(httpFile *multipart.FileHeader) (string, error) { h := sha256.New() - file, err := http_file.Open() + file, err := httpFile.Open() if err != nil { return "", err } @@ -155,13 +157,13 @@ func computeSha(http_file *multipart.FileHeader) (string, error) { return sha, nil } -func getTrainingRun(training_id uint) (*db.TrainingRun, error) { - var training_run db.TrainingRun - err := db.GetDB().Where("id = ?", training_id).First(&training_run).Error +func getTrainingRun(trainingID uint) (*db.TrainingRun, error) { + var trainingRun db.TrainingRun + err := db.GetDB().Where("id = ?", trainingID).First(&trainingRun).Error if err != nil { return nil, err } - return &training_run, nil + return &trainingRun, nil } func uploadNetwork(c *gin.Context) { @@ -198,8 +200,8 @@ func uploadNetwork(c *gin.Context) { // Create new network // TODO(gary): Just hardcoding this for now. - var training_run_id uint = 1 - network.TrainingRunID = training_run_id + var trainingRunID uint = 1 + network.TrainingRunID = trainingRunID layers, err := strconv.ParseInt(c.PostForm("layers"), 10, 32) network.Layers = int(layers) filters, err := strconv.ParseInt(c.PostForm("filters"), 10, 32) @@ -227,16 +229,32 @@ func uploadNetwork(c *gin.Context) { } // TODO(gary): Make this more generic - upload to s3 for now - cmd := exec.Command("aws", "s3", "cp", network.Path, "s3://lczero/networks/") - err = cmd.Run() + cmdParams := config.Config.URLs.OnNewNetwork + if len(cmdParams) > 0 { + for i := range cmdParams { + if cmdParams[i] == "%NETWORK_PATH%" { + cmdParams[i] = network.Path + } + } + + cmd := exec.Command(cmdParams[0], cmdParams[1:]...) + err = cmd.Run() + if err != nil { + log.Println(err.Error()) + c.String(500, "Uploading to s3") + return + } + } + + // Create a match to see if this network is better + trainingRun, err := getTrainingRun(trainingRunID) if err != nil { - log.Println(err.Error()) - c.String(500, "Uploading to s3") + log.Println(err) + c.String(500, "Internal error") return } - // Create a match to see if this network is better - training_run, err := getTrainingRun(training_run_id) + params, err := json.Marshal(config.Config.Matches.Parameters) if err != nil { log.Println(err) c.String(500, "Internal error") @@ -244,12 +262,12 @@ func uploadNetwork(c *gin.Context) { } match := db.Match{ - TrainingRunID: training_run_id, + TrainingRunID: trainingRunID, CandidateID: network.ID, - CurrentBestID: training_run.BestNetworkID, + CurrentBestID: trainingRun.BestNetworkID, Done: false, - GameCap: 400, - Parameters: `["--tempdecay=10"]`, + GameCap: config.Config.Matches.Games, + Parameters: string(params[:]), } if c.DefaultPostForm("testonly", "0") == "1" { match.TestOnly = true @@ -269,7 +287,7 @@ func checkEngineVersion(engineVersion string) bool { if err != nil { return false } - target, err := version.NewVersion("0.10") + target, err := version.NewVersion(config.Config.Clients.MinEngineVersion) if err != nil { log.Println("Invalid comparison version, rejecting all clients!!!") return false @@ -380,7 +398,7 @@ func uploadGame(c *gin.Context) { func getNetwork(c *gin.Context) { // lczero.org/cached/ is behind the cloudflare CDN. Redirect to there to ensure // we hit the CDN. - c.Redirect(http.StatusMovedPermanently, "http://lczero.org/cached/network/sha/"+c.Query("sha")) + c.Redirect(http.StatusMovedPermanently, config.Config.URLs.NetworkLocation+c.Query("sha")) } func cachedGetNetwork(c *gin.Context) { @@ -437,7 +455,7 @@ func checkMatchFinished(match_id uint) error { } // Update to our new best network // TODO(SPRT) - passed := calcElo(match.Wins, match.Losses, match.Draws) > -150.0 + passed := calcElo(match.Wins, match.Losses, match.Draws) > config.Config.Matches.Threshold err = db.GetDB().Model(&match).Update("passed", passed).Error if err != nil { return err @@ -1123,7 +1141,7 @@ func viewTrainingData(c *gin.Context) { } for game_id < int(id) { files = append([]gin.H{ - gin.H{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/games%d.tar.gz", game_id)}, + {"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/games%d.tar.gz", game_id)}, }, files...) game_id += 10000 } @@ -1132,7 +1150,7 @@ func viewTrainingData(c *gin.Context) { pgnId := 9000000 for pgnId < int(id-500000) { pgnFiles = append([]gin.H{ - gin.H{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/run1/pgn%d.tar.gz", pgnId)}, + {"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/run1/pgn%d.tar.gz", pgnId)}, }, pgnFiles...) pgnId += 100000 } @@ -1187,10 +1205,10 @@ func setupRouter() *gin.Engine { } func main() { - db.Init(true) + db.Init() db.SetupDB() defer db.Close() router := setupRouter() - router.Run(":8080") + router.Run(config.Config.WebServer.Address) } diff --git a/go/src/server/serverconfig.json b/go/src/server/serverconfig.json new file mode 100644 index 000000000..28ca82619 --- /dev/null +++ b/go/src/server/serverconfig.json @@ -0,0 +1,24 @@ +{ + "database": { + "host": "localhost", + "user": "gorm", + "dbname": "gorm", + "password": "gorm" + }, + "clients": { + "minClientVersion": 10, + "minEngineVersion": "v0.10" + }, + "urls": { + "onNewNetwork": ["aws", "s3", "cp", "%NETWORK_PATH%", "s3://lczero/networks/"], + "networkLocation": "/cached/network/sha/" + }, + "matches": { + "games": 400, + "parameters": ["--tempdecay=10"], + "threshold": -150.0 + }, + "webserver": { + "address": ":8080" + } +}