diff --git a/lc0_main.go b/lc0_main.go index b5a73e0..da5da6d 100644 --- a/lc0_main.go +++ b/lc0_main.go @@ -61,7 +61,7 @@ var ( backopts = flag.String("backend-opts", "", `Options for the lc0 mux. backend. Example: --backend-opts="cudnn(gpu=1)"`) parallel = flag.Int("parallelism", -1, "Number of games to play in parallel (-1 for default)") - cacheDir = flag.String("cache", "", "Directory to use for downloaded networks cache") + cacheDir = flag.String("cache", "", "Directory to use for downloaded files cache (if it exists)") useTestServer = flag.Bool("use-test-server", false, "Set host name to test server.") runId = flag.Uint("run", 0, "Which training run to contribute to (default 0 to let server decide)") keep = flag.Bool("keep", false, "Do not delete old network files") @@ -133,7 +133,7 @@ func getExtraParams() map[string]string { return map[string]string{ "user": *user, "password": *password, - "version": "28", + "version": "29", "token": strconv.Itoa(randId), "train_only": strconv.FormatBool(*trainOnly), "hostname": *localHost, @@ -363,11 +363,11 @@ func (c *cmdWrapper) launch(networkPath string, otherNetPath string, args []stri testedDxNet = networkPath } if *backopts != "" { - // Check against small token blacklist, currently only "random" + // Check against small token blacklist, currently only "mlh", "random" and "recordreplay" tokens := regexp.MustCompile("[,=().0-9]").Split(*backopts, -1) for _, token := range tokens { switch token { - case "random": + case "mlh", "random", "recordreplay": log.Fatalf("Not accepted in --backend-opts: %s", token) } } @@ -815,12 +815,9 @@ func acquireLock(dir string, sha string) (lockfile.Lockfile, error) { return lock, err } -func getNetwork(httpClient *http.Client, sha string, keepTime string) (string, error) { - dir := "client-cache" - if len(*cacheDir) != 0 { - dir = *cacheDir - } else { - userCache := "" +func makeCacheDir(dir string) string { + userCache := *cacheDir + if len(userCache) == 0 { if runtime.GOOS == "linux" { userCache = os.Getenv("XDG_CACHE_HOME") if len(userCache) == 0 { @@ -835,16 +832,22 @@ func getNetwork(httpClient *http.Client, sha string, keepTime string) (string, e userCache = homeDir + "/Library/Caches" } } - - if len(userCache) != 0 { - _, err := os.Stat(userCache) - if err == nil { + } + if len(userCache) != 0 { + _, err := os.Stat(userCache) + if err == nil { + if len(*cacheDir) == 0 { userCache = filepath.Join(userCache, "lc0") - dir = filepath.Join(userCache, dir) } + dir = filepath.Join(userCache, dir) } } os.MkdirAll(dir, os.ModePerm) + return dir +} + +func getNetwork(httpClient *http.Client, sha string, keepTime string) (string, error) { + dir := makeCacheDir("client-cache") if keepTime != inf { err := removeAllExcept(dir, sha, keepTime) if err != nil { @@ -880,9 +883,32 @@ func getNetwork(httpClient *http.Client, sha string, keepTime string) (string, e return checkValidNetwork(dir, sha) } -func getBook(httpClient *http.Client, book_url string) (string, error) { - dir := "books" - os.MkdirAll(dir, os.ModePerm) +func checkValidBook(path string, sha string) (string, error) { + // File already exists? + _, err := os.Stat(path) + if err == nil { + file, _ := os.Open(path) + sum := sha256.New() + _, err := io.Copy(sum, file); + got := fmt.Sprintf("%x", sum.Sum(nil)) + if sha != got { + text := fmt.Sprintf("book sha mismatch want:\n%s\ngot\n%s\n", sha, got) + err = errors.New(text) + } + file.Close() + if err != nil { + fmt.Printf("Deleting invalid book...\n") + os.Remove(path) + return path, err + } else { + return path, nil + } + } + return path, err +} + +func getBook(httpClient *http.Client, book_url string, sha string) (string, error) { + dir := makeCacheDir("books") u, err := url.Parse(book_url) if err != nil { log.Println("Unable to parse book URL") @@ -891,7 +917,7 @@ func getBook(httpClient *http.Client, book_url string) (string, error) { s := strings.Split(u.Path, "/") book_name := s[len(s)-1] path := filepath.Join(dir, book_name) - _, err = os.Stat(path) + _, err = checkValidBook(path, sha) if err == nil { // Book is there, use it. return path, nil @@ -934,7 +960,7 @@ func getBook(httpClient *http.Client, book_url string) (string, error) { // Ensure tmpfile is erased os.Remove(out.Name()) - return path, err + return checkValidBook(path, sha) } func nextGame(httpClient *http.Client, count int) error { @@ -958,10 +984,17 @@ func nextGame(httpClient *http.Client, count int) error { log.Printf("serverParams: %s", serverParams) if nextGame.BookUrl != "" { - _, err := getBook(&http.Client{}, nextGame.BookUrl) + book, err := getBook(&http.Client{}, nextGame.BookUrl, nextGame.BookSha) if err != nil { return err } + // Replace the book file with the correct path + for i := range serverParams { + if strings.HasPrefix(serverParams[i], "--openings-pgn=") { + serverParams[i] = "--openings-pgn=" + book + break + } + } } if nextGame.Type == "match" { diff --git a/src/client/client_http.go b/src/client/client_http.go index 39b7723..b2bee1d 100644 --- a/src/client/client_http.go +++ b/src/client/client_http.go @@ -87,6 +87,7 @@ type NextGameResponse struct { MatchGameId uint KeepTime string BookUrl string + BookSha string } func NextGame(httpClient *http.Client, hostname string, params map[string]string) (NextGameResponse, error) {