Skip to content

Commit

Permalink
Merge pull request #664 from mooskagh/server-config
Browse files Browse the repository at this point in the history
Move hardcoded things out of server code into config file.
  • Loading branch information
glinscott committed Jun 2, 2018
1 parent e14e85a commit 75ac923
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 53 deletions.
6 changes: 3 additions & 3 deletions go/src/server/cmd/bootstrap/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
)

func main() {
db.Init(true)
db.SetupDB()
db.CreateTrainingRun()
db.Init()
defer db.Close()
db.SetupDB()
db.CreateTrainingRun("Initial run just for test")
}
43 changes: 43 additions & 0 deletions go/src/server/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package config

import (
"encoding/json"
"io/ioutil"
)

// Config is a Server config.
var Config struct {
Database struct {
Host string
User string
Dbname string
Password string
}
Clients struct {
MinClientVersion uint64
MinEngineVersion string
}
URLs struct {
OnNewNetwork []string
NetworkLocation string
}
Matches struct {
Games int
Parameters []interface{}
Threshold float64
}
WebServer struct {
Address string
}
}

func init() {
content, err := ioutil.ReadFile("serverconfig.json")
if err != nil {
panic(err)
}
err = json.Unmarshal(content, &Config)
if err != nil {
panic(err)
}
}
27 changes: 18 additions & 9 deletions go/src/server/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,30 @@ import (
"log"

"github.com/jinzhu/gorm"
// Importing to support postgre database.
_ "github.com/jinzhu/gorm/dialects/postgres"
"server/config"
)

var db *gorm.DB
var err error

func Init(prod bool) {
dbname := "gorm_test"
if prod {
dbname = "gorm"
}
conn := fmt.Sprintf("host=localhost user=gorm dbname=%s sslmode=disable password=gorm", dbname)
// Init initializes database.
func Init() {
conn := fmt.Sprintf(
"host=%s user=%s dbname=%s sslmode=disable password=%s",
config.Config.Database.Host,
config.Config.Database.User,
config.Config.Database.Dbname,
config.Config.Database.Password,
)
db, err = gorm.Open("postgres", conn)
if err != nil {
log.Fatal("Unable to connect to DB", err)
}
}

// SetupDB setups DB.
func SetupDB() {
db.AutoMigrate(&User{})
db.AutoMigrate(&TrainingRun{})
Expand All @@ -32,19 +38,22 @@ func SetupDB() {
db.AutoMigrate(&TrainingGame{})
}

// CreateTrainingRun creates training run
func CreateTrainingRun(description string) *TrainingRun {
training_run := TrainingRun{Description: description}
err := db.Create(&training_run).Error
trainingRun := TrainingRun{Description: description}
err := db.Create(&trainingRun).Error
if err != nil {
log.Fatal(err)
}
return &training_run
return &trainingRun
}

// GetDB returns current database object
func GetDB() *gorm.DB {
return db
}

// Close closes database
func Close() {
db.Close()
}
100 changes: 59 additions & 41 deletions go/src/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"compress/gzip"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -14,6 +15,7 @@ import (
"os"
"os/exec"
"path/filepath"
"server/config"
"server/db"
"strconv"
"strings"
Expand Down Expand Up @@ -49,9 +51,9 @@ func checkUser(c *gin.Context) (*db.User, uint64, error) {
if err != nil {
return nil, 0, errors.New("Invalid version")
}
if version < 10 {
if version < config.Config.Clients.MinClientVersion {
log.Printf("Rejecting old game from %s, version %d\n", user.Username, version)
return nil, 0, errors.New("\n\n\n\n\nYou must upgrade to a newer version!!\n\n\n\n\n")
return nil, 0, errors.New("you must upgrade to a newer version")
}

return user, version, nil
Expand All @@ -65,22 +67,22 @@ func nextGame(c *gin.Context) {
return
}

training_run := db.TrainingRun{
trainingRun := db.TrainingRun{
Active: true,
}
// TODO(gary): Only really supports one training run right now...
err = db.GetDB().Where(&training_run).First(&training_run).Error
err = db.GetDB().Where(&trainingRun).First(&trainingRun).Error
if err != nil {
log.Println(err)
c.String(http.StatusBadRequest, "Invalid training run")
return
}

network := db.Network{}
err = db.GetDB().Where("id = ?", training_run.BestNetworkID).First(&network).Error
err = db.GetDB().Where("id = ?", trainingRun.BestNetworkID).First(&network).Error
if err != nil {
log.Println(err)
c.String(500, "Internal error")
c.String(500, "Internal error 1")
return
}

Expand All @@ -89,28 +91,28 @@ func nextGame(c *gin.Context) {
err = db.GetDB().Preload("Candidate").Where("done=false").Limit(1).Find(&match).Error
if err != nil {
log.Println(err)
c.String(500, "Internal error")
c.String(500, "Internal error 2")
return
}
if len(match) > 0 {
// Return this match
match_game := db.MatchGame{
matchGame := db.MatchGame{
UserID: user.ID,
MatchID: match[0].ID,
}
err = db.GetDB().Create(&match_game).Error
err = db.GetDB().Create(&matchGame).Error
// Note, this could cause an imbalance of white/black games for a particular match,
// but it's good enough for now.
flip := (match_game.ID & 1) == 1
db.GetDB().Model(&match_game).Update("flip", flip)
flip := (matchGame.ID & 1) == 1
db.GetDB().Model(&matchGame).Update("flip", flip)
if err != nil {
log.Println(err)
c.String(500, "Internal error")
c.String(500, "Internal error 3")
return
}
result := gin.H{
"type": "match",
"matchGameId": match_game.ID,
"matchGameId": matchGame.ID,
"sha": network.Sha,
"candidateSha": match[0].Candidate.Sha,
"params": match[0].Parameters,
Expand All @@ -123,18 +125,18 @@ func nextGame(c *gin.Context) {

result := gin.H{
"type": "train",
"trainingId": training_run.ID,
"networkId": training_run.BestNetworkID,
"trainingId": trainingRun.ID,
"networkId": trainingRun.BestNetworkID,
"sha": network.Sha,
"params": training_run.TrainParameters,
"params": trainingRun.TrainParameters,
}
c.JSON(http.StatusOK, result)
}

// Computes SHA256 of gzip compressed file
func computeSha(http_file *multipart.FileHeader) (string, error) {
func computeSha(httpFile *multipart.FileHeader) (string, error) {
h := sha256.New()
file, err := http_file.Open()
file, err := httpFile.Open()
if err != nil {
return "", err
}
Expand All @@ -155,13 +157,13 @@ func computeSha(http_file *multipart.FileHeader) (string, error) {
return sha, nil
}

func getTrainingRun(training_id uint) (*db.TrainingRun, error) {
var training_run db.TrainingRun
err := db.GetDB().Where("id = ?", training_id).First(&training_run).Error
func getTrainingRun(trainingID uint) (*db.TrainingRun, error) {
var trainingRun db.TrainingRun
err := db.GetDB().Where("id = ?", trainingID).First(&trainingRun).Error
if err != nil {
return nil, err
}
return &training_run, nil
return &trainingRun, nil
}

func uploadNetwork(c *gin.Context) {
Expand Down Expand Up @@ -198,8 +200,8 @@ func uploadNetwork(c *gin.Context) {

// Create new network
// TODO(gary): Just hardcoding this for now.
var training_run_id uint = 1
network.TrainingRunID = training_run_id
var trainingRunID uint = 1
network.TrainingRunID = trainingRunID
layers, err := strconv.ParseInt(c.PostForm("layers"), 10, 32)
network.Layers = int(layers)
filters, err := strconv.ParseInt(c.PostForm("filters"), 10, 32)
Expand Down Expand Up @@ -227,29 +229,45 @@ func uploadNetwork(c *gin.Context) {
}

// TODO(gary): Make this more generic - upload to s3 for now
cmd := exec.Command("aws", "s3", "cp", network.Path, "s3://lczero/networks/")
err = cmd.Run()
cmdParams := config.Config.URLs.OnNewNetwork
if len(cmdParams) > 0 {
for i := range cmdParams {
if cmdParams[i] == "%NETWORK_PATH%" {
cmdParams[i] = network.Path
}
}

cmd := exec.Command(cmdParams[0], cmdParams[1:]...)
err = cmd.Run()
if err != nil {
log.Println(err.Error())
c.String(500, "Uploading to s3")
return
}
}

// Create a match to see if this network is better
trainingRun, err := getTrainingRun(trainingRunID)
if err != nil {
log.Println(err.Error())
c.String(500, "Uploading to s3")
log.Println(err)
c.String(500, "Internal error")
return
}

// Create a match to see if this network is better
training_run, err := getTrainingRun(training_run_id)
params, err := json.Marshal(config.Config.Matches.Parameters)
if err != nil {
log.Println(err)
c.String(500, "Internal error")
return
}

match := db.Match{
TrainingRunID: training_run_id,
TrainingRunID: trainingRunID,
CandidateID: network.ID,
CurrentBestID: training_run.BestNetworkID,
CurrentBestID: trainingRun.BestNetworkID,
Done: false,
GameCap: 400,
Parameters: `["--tempdecay=10"]`,
GameCap: config.Config.Matches.Games,
Parameters: string(params[:]),
}
if c.DefaultPostForm("testonly", "0") == "1" {
match.TestOnly = true
Expand All @@ -269,7 +287,7 @@ func checkEngineVersion(engineVersion string) bool {
if err != nil {
return false
}
target, err := version.NewVersion("0.10")
target, err := version.NewVersion(config.Config.Clients.MinEngineVersion)
if err != nil {
log.Println("Invalid comparison version, rejecting all clients!!!")
return false
Expand Down Expand Up @@ -380,7 +398,7 @@ func uploadGame(c *gin.Context) {
func getNetwork(c *gin.Context) {
// lczero.org/cached/ is behind the cloudflare CDN. Redirect to there to ensure
// we hit the CDN.
c.Redirect(http.StatusMovedPermanently, "http://lczero.org/cached/network/sha/"+c.Query("sha"))
c.Redirect(http.StatusMovedPermanently, config.Config.URLs.NetworkLocation+c.Query("sha"))
}

func cachedGetNetwork(c *gin.Context) {
Expand Down Expand Up @@ -437,7 +455,7 @@ func checkMatchFinished(match_id uint) error {
}
// Update to our new best network
// TODO(SPRT)
passed := calcElo(match.Wins, match.Losses, match.Draws) > -150.0
passed := calcElo(match.Wins, match.Losses, match.Draws) > config.Config.Matches.Threshold
err = db.GetDB().Model(&match).Update("passed", passed).Error
if err != nil {
return err
Expand Down Expand Up @@ -1123,7 +1141,7 @@ func viewTrainingData(c *gin.Context) {
}
for game_id < int(id) {
files = append([]gin.H{
gin.H{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/games%d.tar.gz", game_id)},
{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/games%d.tar.gz", game_id)},
}, files...)
game_id += 10000
}
Expand All @@ -1132,7 +1150,7 @@ func viewTrainingData(c *gin.Context) {
pgnId := 9000000
for pgnId < int(id-500000) {
pgnFiles = append([]gin.H{
gin.H{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/run1/pgn%d.tar.gz", pgnId)},
{"url": fmt.Sprintf("https://s3.amazonaws.com/lczero/training/run1/pgn%d.tar.gz", pgnId)},
}, pgnFiles...)
pgnId += 100000
}
Expand Down Expand Up @@ -1187,10 +1205,10 @@ func setupRouter() *gin.Engine {
}

func main() {
db.Init(true)
db.Init()
db.SetupDB()
defer db.Close()

router := setupRouter()
router.Run(":8080")
router.Run(config.Config.WebServer.Address)
}
Loading

0 comments on commit 75ac923

Please sign in to comment.