diff --git a/src/chess/callbacks.h b/src/chess/callbacks.h index 79ebf2ff8b..2beae8a65b 100644 --- a/src/chess/callbacks.h +++ b/src/chess/callbacks.h @@ -69,6 +69,8 @@ struct ThinkingInfo { int hashfull = -1; // Win in centipawns. optional score; + // Distance to mate. + optional mate; // Number of successful TB probes (not the same as playouts ending in TB hit). int tb_hits = -1; // Best line found. Moves are from perspective of white player. diff --git a/src/chess/position.h b/src/chess/position.h index 922727e64a..eeac0204c1 100644 --- a/src/chess/position.h +++ b/src/chess/position.h @@ -82,6 +82,13 @@ class Position { }; enum class GameResult { UNDECIDED, WHITE_WON, DRAW, BLACK_WON }; +enum class CertaintyTrigger { NONE, TB_HIT, TWO_FOLD, TERMINAL, NORMAL }; + +struct CertaintyResult { + public: + GameResult gameresult; + CertaintyTrigger trigger; +}; class PositionHistory { public: diff --git a/src/chess/uciloop.cc b/src/chess/uciloop.cc index 9d00185c01..1e6f09bdea 100644 --- a/src/chess/uciloop.cc +++ b/src/chess/uciloop.cc @@ -247,7 +247,13 @@ void UciLoop::SendInfo(const std::vector& infos) { if (info.seldepth >= 0) res += " seldepth " + std::to_string(info.seldepth); if (info.time >= 0) res += " time " + std::to_string(info.time); if (info.nodes >= 0) res += " nodes " + std::to_string(info.nodes); - if (info.score) res += " score cp " + std::to_string(*info.score); + + // If mate display mate, otherwise if score display score. + if (info.mate) { + res += " score mate " + std::to_string(*info.mate); + } else if (info.score) { + res += " score cp " + std::to_string(*info.score); + } if (info.hashfull >= 0) res += " hashfull " + std::to_string(info.hashfull); if (info.nps >= 0) res += " nps " + std::to_string(info.nps); if (info.tb_hits >= 0) res += " tbhits " + std::to_string(info.tb_hits); diff --git a/src/engine.cc b/src/engine.cc index eea66e57d9..0eaa2eef6f 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -349,6 +349,7 @@ void EngineController::Go(const GoParams& params) { if (info.multipv <= 1) { ponder_info = info; if (ponder_info.score) ponder_info.score = -*ponder_info.score; + if (ponder_info.mate) ponder_info.mate = -*ponder_info.mate; if (ponder_info.depth > 1) ponder_info.depth--; if (ponder_info.seldepth > 1) ponder_info.seldepth--; ponder_info.pv.clear(); diff --git a/src/mcts/node.cc b/src/mcts/node.cc index a47fca1cc5..f626f8084e 100644 --- a/src/mcts/node.cc +++ b/src/mcts/node.cc @@ -28,6 +28,7 @@ #include "mcts/node.h" #include +#include #include #include #include @@ -160,9 +161,58 @@ float Edge::GetP() const { return ret; } +void Edge::MakeTerminal(GameResult result) { + certainty_state_ |= kTerminalMask | kCertainMask | kUpperBound | kLowerBound; + certainty_state_ &= kGameResultClear; + if (result == GameResult::WHITE_WON) { + certainty_state_ |= kGameResultWin; + } else if (result == GameResult::BLACK_WON) { + certainty_state_ |= kGameResultLoss; + } +} + +void Edge::MakeCertain(CertaintyResult certaintyresult) { + certainty_state_ |= kCertainMask | kUpperBound | kLowerBound; + certainty_state_ &= kGameResultClear; + if (certaintyresult.gameresult == GameResult::WHITE_WON) { + certainty_state_ |= kGameResultWin; + } else if (certaintyresult.gameresult == GameResult::BLACK_WON) { + certainty_state_ |= kGameResultLoss; + } + if (certaintyresult.trigger == CertaintyTrigger::TB_HIT) certainty_state_ |= kTBHit; + if (certaintyresult.trigger == CertaintyTrigger::TWO_FOLD) certainty_state_ |= kTwoFold; +} + +void Edge::MakeCertain(int q, CertaintyTrigger trigger) { + certainty_state_ |= kCertainMask | kUpperBound | kLowerBound; + certainty_state_ &= kGameResultClear; + if (q == 1) { + certainty_state_ |= kGameResultWin; + } else if (q == -1) { + certainty_state_ |= kGameResultLoss; + } + if (trigger == CertaintyTrigger::TB_HIT) certainty_state_ |= kTBHit; + if (trigger == CertaintyTrigger::TWO_FOLD) certainty_state_ |= kTwoFold; +} +void Edge::SetEQ(int eq) { + certainty_state_ &= kGameResultClear; + if (eq == 1) { + certainty_state_ |= kGameResultWin; + } else if (eq == -1) { + certainty_state_ |= kGameResultLoss; + } +} + +int Edge::GetEQ() const { + if (certainty_state_ & kGameResultLoss) return -1; + if (certainty_state_ & kGameResultWin) return 1; + return 0; +} + std::string Edge::DebugString() const { std::ostringstream oss; - oss << "Move: " << move_.as_string() << " p_: " << p_ << " GetP: " << GetP(); + oss << "Move: " << move_.as_string() << " p_: " << p_ << " GetP: " << GetP() + << " Certainty:" << std::bitset<8>(certainty_state_); return oss.str(); } @@ -197,39 +247,51 @@ void Node::CreateEdges(const MoveList& moves) { Node::ConstIterator Node::Edges() const { return {edges_, &child_}; } Node::Iterator Node::Edges() { return {edges_, &child_}; } +void Node::RecomputeNfromChildren() { + if (n_ > 1) { + uint32_t visits = 1; + for (const auto& child : Edges()) visits += child.GetN(); + n_ = visits; + } + assert(n_in_flight_ == 0); +} + float Node::GetVisitedPolicy() const { return visited_policy_; } +float Node::GetQ() const { + // Currently all certain edges have a corresponding node (PR700) and + // that nodes q_ is set correctly. If we later allow edges to become + // certain without creating the node (PR487 through look-ahead-search), + // we need to revisit this getter or adapt search. + // if (parent_) { + // auto edge = parent_->GetEdgeToNode(this); + // if (edge->IsCertain()) return (float)edge->GetEQ(); + // } + return q_; +} + Edge* Node::GetEdgeToNode(const Node* node) const { assert(node->parent_ == this); assert(node->index_ < edges_.size()); return &edges_[node->index_]; } -Edge* Node::GetOwnEdge() const { return GetParent()->GetEdgeToNode(this); } +Edge* Node::GetOwnEdge() const { + if (GetParent()) { + return GetParent()->GetEdgeToNode(this); + } else + return nullptr; +} std::string Node::DebugString() const { std::ostringstream oss; - oss << " Term:" << is_terminal_ << " This:" << this << " Parent:" << parent_ - << " Index:" << index_ << " Child:" << child_.get() - << " Sibling:" << sibling_.get() << " Q:" << q_ << " N:" << n_ - << " N_:" << n_in_flight_ << " Edges:" << edges_.size(); + oss << " This:" << this << " Parent:" << parent_ << " Index:" << index_ + << " Child:" << child_.get() << " Sibling:" << sibling_.get() + << " Q:" << q_ << " N:" << n_ << " N_:" << n_in_flight_ + << " Edges:" << edges_.size(); return oss.str(); } -void Node::MakeTerminal(GameResult result) { - is_terminal_ = true; - if (result == GameResult::DRAW) { - q_ = 0.0f; - d_ = 1.0f; - } else if (result == GameResult::WHITE_WON) { - q_ = 1.0f; - d_ = 0.0f; - } else if (result == GameResult::BLACK_WON) { - q_ = -1.0f; - d_ = 0.0f; - } -} - bool Node::TryStartScoreUpdate() { if (n_ == 0 && n_in_flight_ > 0) return false; ++n_in_flight_; @@ -388,6 +450,13 @@ void NodeTree::MakeMove(Move move) { current_head_->ReleaseChildrenExceptOne(new_head); current_head_ = new_head ? new_head : current_head_->CreateSingleChildNode(move); + // If certain and no children, reset node (so that n_ = 0). + if (current_head_->IsCertain() && !current_head_->HasChildren()) + TrimTreeAtHead(); + // Clear certainty flag but keep bounds. + if (current_head_->GetParent()) + current_head_->GetOwnEdge()->ClearCertaintyState(); + current_head_->RecomputeNfromChildren(); history_.Append(move); } @@ -430,10 +499,19 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen, // previously searched position, which means that the current_head_ might // retain old n_ and q_ (etc) data, even though its old children were // previously trimmed; we need to reset current_head_ in that case. - // Also, if the current_head_ is terminal, reset that as well to allow forced - // analysis of WDL hits, or possibly 3 fold or 50 move "draws", etc. - if (!seen_old_head || current_head_->IsTerminal()) TrimTreeAtHead(); - + // Also, if the current_head_ is certain and has no children, reset that + // as well to allow forced analysis of WDL hits, or possibly 2 or 3 fold + // or 50 move "draws", etc. + if (!seen_old_head || + (current_head_->IsCertain() && !current_head_->HasChildren())) + TrimTreeAtHead(); + // Certainty Propagation: No need to trim the head for certain nodes with + // children (these became certain through backpropagation), just resetting + // certainty state except bounds, and recomputing N suffices. TrimTreeAtHead + // sets n_ to 0 this remains 0 after RecomputeNfromChildren. + if (current_head_->GetParent()) + current_head_->GetOwnEdge()->ClearCertaintyState(); + current_head_->RecomputeNfromChildren(); return seen_old_head; } diff --git a/src/mcts/node.h b/src/mcts/node.h index 02c3d9f5a0..87ad7ad63d 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -86,6 +86,72 @@ class Edge { float GetP() const; void SetP(float val); + void MakeTerminal(GameResult result); + + // Sets flags for certainty, trigger of certainty and result by GameResult. + void MakeCertain(CertaintyResult certaintyresult); + + // Sets flags for certainty, trigger of certainty and result (by Q). + void MakeCertain(int q, CertaintyTrigger trigger); + + // Sets edge-Q: win = 1; draw = 0; loss = -1. + void SetEQ(int eq); + // Clears Certainty but keeps bounds. + void ClearCertaintyState() { + if (certainty_state_ & kCertainMask) { + certainty_state_ = 0; + } else { + certainty_state_ &= kClearKeepBounds; + } + }; + // Sets (U)pper and (L)ower bounds. + void UBound(int eq) { + certainty_state_ |= kUpperBound; + SetEQ(eq); + }; + void LBound(int eq) { + certainty_state_ |= kLowerBound; + SetEQ(eq); + }; + + // Returns true if only upper bounded. + bool IsOnlyUBounded() { + return (certainty_state_ & kUpperBound) && + !(certainty_state_ & kLowerBound); + }; + + // Returns true if only lower bounded. + bool IsOnlyLBounded() { + return !(certainty_state_ & kUpperBound) && + (certainty_state_ & kLowerBound); + }; + bool IsTerminal() const { return certainty_state_ & kTerminalMask; }; + bool IsCertain() const { return certainty_state_ & kCertainMask; }; + bool IsCertainWin() const { + return ((certainty_state_ & kCertainMask) && + (certainty_state_ & kGameResultWin)); + }; + bool IsCertainLoss() const { + return ((certainty_state_ & kCertainMask) && + (certainty_state_ & kGameResultLoss)); + }; + bool IsCertainDraw() const { + return ((certainty_state_ & kCertainMask) && + !(certainty_state_ & ~kGameResultClear)); + }; + + // Returns true if certainty proof based on a TB hit. + bool IsPropagatedTBHit() const { return certainty_state_ & kTBHit; }; + + // Query bounds. + bool IsUBounded() const { return certainty_state_ & kUpperBound; }; + bool IsLBounded() const { return certainty_state_ & kLowerBound; }; + + // Return all stats flags. + uint8_t GetCertaintyState() const { return certainty_state_; }; + // Returns the edges Q + int GetEQ() const; + // Debug information about the edge. std::string DebugString() const; @@ -101,6 +167,28 @@ class Edge { // network; compressed to a 16 bit format (5 bits exp, 11 bits significand). uint16_t p_ = 0; + // Certainty Propagation attaches a number of flags to each edge. + // kTerminalMask -> edge is terminal. + // kCertainMask -> edge is certain. + // kUpperBound -> edge is upper bounded. + // kLowerBound -> edge is lower bounded. + // kTBHit -> edge is certain because of a TB hit. + // kTwoFold -> edge is certain because of a two-fold. + // kGameResultWin -> edge is a proven win. + // kGameResultLoss -> edge is a proven loss. + uint8_t certainty_state_ = 0; + enum Masks : uint8_t { + kTerminalMask = 0b00000001, + kCertainMask = 0b00000010, + kUpperBound = 0b00000100, + kLowerBound = 0b00001000, + kTBHit = 0b00010000, + kTwoFold = 0b00100000, + kGameResultWin = 0b01000000, + kGameResultLoss = 0b10000000, + kGameResultClear = 0b00111111, + kClearKeepBounds = 0b00001100, + }; friend class EdgeList; }; @@ -144,6 +232,11 @@ class Node { // Returns whether a node has children. bool HasChildren() const { return edges_; } + // Recalculate n_ from real children visits. + // This is needed if a node was proved to be certain in a prior + // search and later gets to be root of search. + void RecomputeNfromChildren(); + // Returns sum of policy priors which have had at least one playout. float GetVisitedPolicy() const; uint32_t GetN() const { return n_; } @@ -151,17 +244,74 @@ class Node { uint32_t GetChildrenVisits() const { return n_ > 0 ? n_ - 1 : 0; } // Returns n = n_if_flight. int GetNStarted() const { return n_ + n_in_flight_; } - // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 - // for terminal nodes. - float GetQ() const { return q_; } + // Returns node eval, i.e. average subtree V for non-certain node and -1/0/1 + // for certain nodes. + float GetQ() const; float GetD() const { return d_; } // Returns whether the node is known to be draw/lose/win. - bool IsTerminal() const { return is_terminal_; } + bool IsTerminal() const { + return GetOwnEdge() ? GetOwnEdge()->IsTerminal() : false; + } + bool IsCertain() const { + return GetOwnEdge() ? GetOwnEdge()->IsCertain() : false; + } + + // Sets bounds. + void UBound(int eq) const { + if (GetOwnEdge()) GetOwnEdge()->UBound(eq); + } + void LBound(int eq) const { + if (GetOwnEdge()) GetOwnEdge()->LBound(eq); + } + // Queries bounds. + + bool IsBounded() const { + return GetOwnEdge() + ? GetOwnEdge()->IsLBounded() || GetOwnEdge()->IsUBounded() + : false; + } + bool IsOnlyUBounded() const { + return GetOwnEdge() ? GetOwnEdge()->IsOnlyUBounded() : false; + } uint16_t GetNumEdges() const { return edges_.size(); } - // Makes the node terminal and sets it's score. - void MakeTerminal(GameResult result); + // Makes the node terminal or certain and sets its score. + void MakeTerminal(GameResult result) { + if (GetOwnEdge()) GetOwnEdge()->MakeTerminal(result); + if (result == GameResult::DRAW) { + q_ = 0.0f; + d_ = 1.0f; + } else if (result == GameResult::WHITE_WON) { + q_ = 1.0f; + d_ = 0.0f; + } else { + q_ = -1.0f; + d_ = 0.0f; + } + } + void MakeCertain(CertaintyResult certaintyresult) { + if (GetOwnEdge()) GetOwnEdge()->MakeCertain(certaintyresult); + if (certaintyresult.gameresult == GameResult::DRAW) { + q_ = 0.0f; + d_ = 1.0f; + } else if (certaintyresult.gameresult == GameResult::WHITE_WON) { + q_ = 1.0f; + d_ = 0.0f; + } else { + q_ = -1.0f; + d_ = 0.0f; + } + } + void MakeCertain(int q, CertaintyTrigger trigger) { + if (GetOwnEdge()) GetOwnEdge()->MakeCertain(q, trigger); + q_ = q; + if (q == 0) { + d_ = 1.0f; + } else { + d_ = 0.0f; + } + } // If this node is not in the process of being expanded by another thread // (which can happen only if n==0 and n-in-flight==1), mark the node as @@ -201,10 +351,6 @@ class Node { return best_child_cache_in_flight_limit_ - n_in_flight_; } - // Calculates the full depth if new depth is larger, updates it, returns - // in depth parameter, and returns true if it was indeed updated. - bool UpdateFullDepth(uint16_t* depth); - V4TrainingData GetV4TrainingData(GameResult result, const PositionHistory& history, FillEmptyHistory fill_empty_history, @@ -251,7 +397,8 @@ class Node { EdgeList edges_; // 8 byte fields. - // Pointer to a parent node. nullptr for the root. + // Pointer to a parent node. nullptr for the root of tree, + // Note: root of tree might not be search->root_node_. Node* parent_ = nullptr; // Pointer to a first child. nullptr for a leaf node. std::unique_ptr child_; @@ -283,13 +430,9 @@ class Node { uint32_t best_child_cache_in_flight_limit_ = 0; // 2 byte fields. - // Index of this node is parent's edge list. + // Index of this node in parent's edge list. uint16_t index_; - // 1 byte fields. - // Whether or not this node end game (with a winning of either sides or draw). - bool is_terminal_ = false; - // TODO(mooskagh) Unfriend NodeTree. friend class NodeTree; friend class Edge_Iterator; @@ -340,19 +483,35 @@ class EdgeAndNode { float GetD() const { return (node_ && node_->GetN() > 0) ? node_->GetD() : 0.0f; } - // N-related getters, from Node (if exists). + + // Gets the edge's Q, if edge is certain this + // is the proven game result (-1, 0, +1). + int GetEQ() const { return edge_->GetEQ(); } + + // N-related getters, from node (if exists). uint32_t GetN() const { return node_ ? node_->GetN() : 0; } int GetNStarted() const { return node_ ? node_->GetNStarted() : 0; } uint32_t GetNInFlight() const { return node_ ? node_->GetNInFlight() : 0; } - // Whether the node is known to be terminal. - bool IsTerminal() const { return node_ ? node_->IsTerminal() : false; } - // Edge related getters. float GetP() const { return edge_->GetP(); } Move GetMove(bool flip = false) const { return edge_ ? edge_->GetMove(flip) : Move(); } + bool IsTerminal() const { return edge_->IsTerminal(); } + bool IsCertain() const { return edge_->IsCertain(); } + bool IsCertainWin() const { return edge_->IsCertainWin(); } + + // Queries bounds. + bool IsUBounded() const { return edge_->IsUBounded(); } + bool IsLBounded() const { return edge_->IsLBounded(); } + + // Sets bounds. + void UBound(int eq) { edge_->UBound(eq); } + void LBound(int eq) { edge_->LBound(eq); } + + // Queries if edge's certainty is based on a TB hit. + bool IsPropagatedTBHit() const { return edge_->IsPropagatedTBHit(); } // Returns U = numerator * p / N. // Passed numerator is expected to be equal to (cpuct * sqrt(N[parent])). diff --git a/src/mcts/params.cc b/src/mcts/params.cc index 0190584c91..02e3cee7b9 100644 --- a/src/mcts/params.cc +++ b/src/mcts/params.cc @@ -172,6 +172,14 @@ const OptionId SearchParams::kKLDGainAverageInterval{ "kldgain-average-interval", "KLDGainAverageInterval", "Used to decide how frequently to evaluate the average KLDGainPerNode to " "check the MinimumKLDGainPerNode, if specified."}; +const OptionId SearchParams::kCertaintyPropagationId{ + "certainty-propagation", "CertaintyPropagation", + "Propagates certain scores more efficiently in the search tree, " + "proves and displays mates."}; +const OptionId SearchParams::kTwoFoldDrawScoringId{ + "two-fold-draw-scoring", "TwoFoldDrawScoring", + "Scores two-folds as draws (0.00) in search to use visits more " + "efficiently. Recommended in conjunction with certainty propagation."}; void SearchParams::Populate(OptionsParser* options) { // Here the uci optimized defaults" are set. @@ -208,6 +216,8 @@ void SearchParams::Populate(OptionsParser* options) { options->Add(kHistoryFillId, history_fill_opt) = "fen_only"; options->Add(kKLDGainAverageInterval, 1, 10000000) = 100; options->Add(kMinimumKLDGainPerNode, 0.0f, 1.0f) = 0.0f; + options->Add(kCertaintyPropagationId) = false; + options->Add(kTwoFoldDrawScoringId) = false; } SearchParams::SearchParams(const OptionsDict& options) @@ -226,6 +236,8 @@ SearchParams::SearchParams(const OptionsDict& options) kMaxCollisionEvents(options.Get(kMaxCollisionEventsId.GetId())), kMaxCollisionVisits(options.Get(kMaxCollisionVisitsId.GetId())), kOutOfOrderEval(options.Get(kOutOfOrderEvalId.GetId())), + kCertaintyPropagation(options.Get(kCertaintyPropagationId.GetId())), + kTwoFoldDrawScoring(options.Get(kTwoFoldDrawScoringId.GetId())), kSyzygyFastPlay(options.Get(kSyzygyFastPlayId.GetId())), kHistoryFill( EncodeHistoryFill(options.Get(kHistoryFillId.GetId()))), diff --git a/src/mcts/params.h b/src/mcts/params.h index c49357fab5..fbc674cfa1 100644 --- a/src/mcts/params.h +++ b/src/mcts/params.h @@ -89,6 +89,8 @@ class SearchParams { return options_.Get(kScoreTypeId.GetId()); } FillEmptyHistory GetHistoryFill() const { return kHistoryFill; } + bool GetCertaintyPropagation() const { return kCertaintyPropagation; } + bool GetTwoFoldDrawScoring() const { return kTwoFoldDrawScoring; } int GetKLDGainAverageInterval() const { return options_.Get(kKLDGainAverageInterval.GetId()); } @@ -123,6 +125,8 @@ class SearchParams { static const OptionId kMultiPvId; static const OptionId kScoreTypeId; static const OptionId kHistoryFillId; + static const OptionId kCertaintyPropagationId; + static const OptionId kTwoFoldDrawScoringId; static const OptionId kMinimumKLDGainPerNode; static const OptionId kKLDGainAverageInterval; @@ -147,6 +151,8 @@ class SearchParams { const int kMaxCollisionEvents; const int kMaxCollisionVisits; const bool kOutOfOrderEval; + const bool kCertaintyPropagation; + const bool kTwoFoldDrawScoring; const bool kSyzygyFastPlay; const FillEmptyHistory kHistoryFill; const int kMiniBatchSize; diff --git a/src/mcts/search.cc b/src/mcts/search.cc index 81ffbd347f..7cea56942f 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -28,6 +28,7 @@ #include "mcts/search.h" #include +#include #include #include #include @@ -118,18 +119,18 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) { common_info.nps = common_info.time ? (total_playouts_ * 1000 / common_info.time) : 0; common_info.tb_hits = tb_hits_.load(std::memory_order_acquire); - int multipv = 0; for (const auto& edge : edges) { + float score = edge.GetQ(-root_node_->GetQ()); ++multipv; uci_infos.emplace_back(common_info); auto& uci_info = uci_infos.back(); if (score_type == "centipawn") { - uci_info.score = 290.680623072 * tan(1.548090806 * edge.GetQ(0)); + uci_info.score = 290.680623072 * tan(1.548090806 * score); } else if (score_type == "win_percentage") { - uci_info.score = edge.GetQ(0) * 5000 + 5000; + uci_info.score = score * 5000 + 5000; } else if (score_type == "Q") { - uci_info.score = edge.GetQ(0) * 10000; + uci_info.score = score * 10000; } if (params_.GetMultiPv() > 1) uci_info.multipv = multipv; bool flip = played_history_.IsBlackToMove(); @@ -138,6 +139,19 @@ void Search::SendUciInfo() REQUIRES(nodes_mutex_) { uci_info.pv.push_back(iter.GetMove(flip)); if (!iter.node()) break; // Last edge was dangling, cannot continue. } + + // Mate display if certain win (or loss) with distance to mate set to + // length of pv (average mate). + // If win is based on propagated TB bit, length of mate is + // adjusted by +1000; If root filtered TB moves are draw display 0. + if (params_.GetCertaintyPropagation()) { + if (edge.IsCertain() && edge.GetEQ() != 0) { + uci_info.mate = edge.GetEQ() * ((uci_info.pv.size() + 1) / 2 + + (edge.IsPropagatedTBHit() ? 1000 : 0)); + } else if (root_syzygy_rank_ == 1) { + uci_info.score = 0; + } + } } if (!uci_infos.empty()) last_outputted_uci_info_ = uci_infos.front(); @@ -239,8 +253,8 @@ std::vector Search::GetVerboseStats(Node* node, oss << "(Q: " << std::setw(8) << std::setprecision(5) << edge.GetQ(fpu) << ") "; - oss << "(D: " << std::setw(6) << std::setprecision(3) - << edge.GetD() << ") "; + oss << "(D: " << std::setw(6) << std::setprecision(3) << edge.GetD() + << ") "; oss << "(U: " << std::setw(6) << std::setprecision(5) << edge.GetU(U_coeff) << ") "; @@ -250,8 +264,8 @@ std::vector Search::GetVerboseStats(Node* node, oss << "(V: "; optional v; - if (edge.IsTerminal()) { - v = edge.node()->GetQ(); + if (edge.IsCertain()) { + v = edge.edge()->GetEQ(); } else { NNCacheLock nneval = GetCachedNNEval(edge.node()); if (nneval) v = -nneval->q; @@ -263,7 +277,8 @@ std::vector Search::GetVerboseStats(Node* node, } oss << ") "; - if (edge.IsTerminal()) oss << "(T) "; + oss << " C:" << std::bitset<8>(edge.edge()->GetCertaintyState()); + infos.emplace_back(oss.str()); } return infos; @@ -485,22 +500,29 @@ std::int64_t Search::GetTotalPlayouts() const { return total_playouts_; } -bool Search::PopulateRootMoveLimit(MoveList* root_moves) const { +int Search::PopulateRootMoveLimit(MoveList* root_moves) const { // Search moves overrides tablebase. if (!limits_.searchmoves.empty()) { *root_moves = limits_.searchmoves; - return false; + return 0; } + + // Syzygy root_probe returns best_rank for proper eval if + // moves are syzygy root filtered. auto board = played_history_.Last().GetBoard(); if (!syzygy_tb_ || !board.castlings().no_legal_castle() || (board.ours() | board.theirs()).count() > syzygy_tb_->max_cardinality()) { - return false; + return 0; } - return syzygy_tb_->root_probe(played_history_.Last(), - params_.GetSyzygyFastPlay() || - played_history_.DidRepeatSinceLastZeroingMove(), - root_moves) || - syzygy_tb_->root_probe_wdl(played_history_.Last(), root_moves); + + int best_rank = syzygy_tb_->root_probe( + played_history_.Last(), + params_.GetSyzygyFastPlay() || + played_history_.DidRepeatSinceLastZeroingMove(), + root_moves); + if (!best_rank) + best_rank = syzygy_tb_->root_probe_wdl(played_history_.Last(), root_moves); + return best_rank; } // Computes the best move, maybe with temperature (according to the settings). @@ -540,11 +562,15 @@ std::vector Search::GetBestChildrenNoTemperature(Node* parent, PopulateRootMoveLimit(&root_limit); } // Best child is selected using the following criteria: + // with Certainty Propagation: + // * Prefer terminal wins, then certain wins. + // * Avoid losses, but prefer certain losses over terminal losses. + // Otherwise: // * Largest number of playouts. // * If two nodes have equal number: // * If that number is 0, the one with larger prior wins. // * If that number is larger than 0, the one with larger eval wins. - using El = std::tuple; + using El = std::tuple; std::vector edges; for (auto edge : parent->Edges()) { if (parent == root_node_ && !root_limit.empty() && @@ -552,19 +578,41 @@ std::vector Search::GetBestChildrenNoTemperature(Node* parent, root_limit.end()) { continue; } - edges.emplace_back(edge.GetN(), edge.GetQ(0), edge.GetP(), edge); + edges.emplace_back((params_.GetCertaintyPropagation()) + ? edge.edge()->GetEQ() * (edge.IsTerminal() + 1) + : 0, + edge.GetN(), edge.GetQ(0), edge.GetP(), edge); + } + // Ensure that certain draws have at least as many virtual visits as the + // first move with Q<=0 (these visits are used during final sort). + // The result is that they're always preferred over moves with Q<0. + // Certain draws with more visits are left as is, so that they are + // preferred over all less-explored moves, even if those moves have Q>0; + // in this respect behaviour is identical to normal leela. + if (params_.GetCertaintyPropagation()) { + std::partial_sort(edges.begin(), edges.end(), edges.end(), + std::greater()); + // largest N with Q >= 0 + uint64_t largest_N = 0; + for (auto it = edges.begin(); it != edges.end(); ++it) { + if (std::get<2>(*it) <= 0.0f && largest_N == 0) + largest_N = std::get<1>(*it); + if (std::get<4>(*it).edge()->IsCertainDraw() && largest_N > 0) + std::get<1>(*it) = largest_N; + } } + // Final sort pass. auto middle = (static_cast(edges.size()) > count) ? edges.begin() + count : edges.end(); std::partial_sort(edges.begin(), middle, edges.end(), std::greater()); std::vector res; std::transform(edges.begin(), middle, std::back_inserter(res), - [](const El& x) { return std::get<3>(x); }); + [](const El& x) { return std::get<4>(x); }); return res; } -// Returns a child with most visits. +// Returns best child. EdgeAndNode Search::GetBestChildNoTemperature(Node* parent) const { auto res = GetBestChildrenNoTemperature(parent, 1); return res.empty() ? EdgeAndNode() : res.front(); @@ -763,8 +811,10 @@ void SearchWorker::InitializeIteration( if (!root_move_filter_populated_) { root_move_filter_populated_ = true; - if (search_->PopulateRootMoveLimit(&root_move_filter_)) { + int best_rank = search_->PopulateRootMoveLimit(&root_move_filter_); + if (best_rank) { search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel); + search_->root_syzygy_rank_ = best_rank; } } } @@ -808,8 +858,8 @@ void SearchWorker::GatherMinibatch() { // Node was never visited, extend it. ExtendNode(node); - // Only send non-terminal nodes to a neural network. - if (!node->IsTerminal()) { + // Only send uncertain nodes to a neural network. + if (!node->IsCertain()) { picked_node.nn_queried = true; picked_node.is_cache_hit = AddNodeToComputation(node, true); } @@ -900,13 +950,12 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( } return NodeToProcess::Collision(node, depth, collision_limit); } - // Either terminal or unexamined leaf node -- the end of this playout. - if (!node->HasChildren()) { - if (node->IsTerminal()) { - return NodeToProcess::TerminalHit(node, depth, 1); - } else { - return NodeToProcess::Extension(node, depth); - } + // Either terminal/certain or unexamined leaf node -- the end of this + // playout. + if (node->IsCertain()) { + return NodeToProcess::TerminalHit(node, depth, 1); + } else if (!node->HasChildren()) { + return NodeToProcess::Extension(node, depth); } Node* possible_shortcut_child = node->GetCachedBestChild(); if (possible_shortcut_child) { @@ -930,6 +979,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( float second_best = std::numeric_limits::lowest(); int possible_moves = 0; const float fpu = GetFpu(params_, node, is_root_node); + bool parent_upperbounded = node->IsOnlyUBounded(); for (auto child : node->Edges()) { if (is_root_node) { // If there's no chance to catch up to the current best node with @@ -941,6 +991,19 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( search_->remaining_playouts_ < best_node_n - child.GetN()) { continue; } + // If play certain win and don't search other + // moves at root. If search limit infinite continue searching other + // moves. + if (params_.GetCertaintyPropagation() && child.edge()->IsCertainWin()) { + if (!search_->limits_.infinite) { + best_edge = child; + possible_moves = 1; + break; + } else if (search_->current_best_edge_ == child && + possible_moves > 0) { + continue; + } + } // If root move filter exists, make sure move is in the list. if (!root_move_filter_.empty() && std::find(root_move_filter_.begin(), root_move_filter_.end(), @@ -950,6 +1013,19 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( ++possible_moves; } float Q = child.GetQ(fpu); + + // Certainty Propagation. Avoid suboptimal childs. + if (params_.GetCertaintyPropagation()) { + // Prefers lower bounded childs over drawing children. + if (child.edge()->IsOnlyLBounded() && child.GetQ(0) <= 0.0f) Q = 0.01f; + // Prefers drawing children over upper bounded childs. + if (child.edge()->IsOnlyUBounded() && child.GetQ(0) >= 0.0f) Q = -0.01f; + // Penalize exploring suboptimal childs throughout the tree. + if (parent_upperbounded) { + if (child.edge()->IsOnlyUBounded()) Q -= child.GetN() * 0.1f; + } + } + const float score = child.GetU(puct_mult) + Q; if (score > best) { second_best = best; @@ -986,60 +1062,46 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( } } -void SearchWorker::ExtendNode(Node* node) { - // Initialize position sequence with pre-move position. - history_.Trim(search_->played_history_.GetLength()); - std::vector to_add; - // Could instead reserve one more than the difference between history_.size() - // and history_.capacity(). - to_add.reserve(60); - Node* cur = node; - while (cur != search_->root_node_) { - Node* prev = cur->GetParent(); - to_add.push_back(prev->GetEdgeToNode(cur)->GetMove()); - cur = prev; - } - for (int i = to_add.size() - 1; i >= 0; i--) { - history_.Append(to_add[i]); - } - - // We don't need the mutex because other threads will see that N=0 and - // N-in-flight=1 and will not touch this node. - const auto& board = history_.Last().GetBoard(); - auto legal_moves = board.GenerateLegalMoves(); - +CertaintyResult SearchWorker::EvalPosition(const Node* node, + const MoveList& legal_moves, + const ChessBoard& board) { + CertaintyResult certaintyresult = { GameResult::UNDECIDED, + CertaintyTrigger::NONE }; // Check whether it's a draw/lose by position. Importantly, we must check // these before doing the by-rule checks below. if (legal_moves.empty()) { // Could be a checkmate or a stalemate if (board.IsUnderCheck()) { - node->MakeTerminal(GameResult::WHITE_WON); + certaintyresult = {GameResult::WHITE_WON, CertaintyTrigger::TERMINAL}; } else { - node->MakeTerminal(GameResult::DRAW); + certaintyresult = {GameResult::DRAW, CertaintyTrigger::TERMINAL}; } - return; + return certaintyresult; } // We can shortcircuit these draws-by-rule only if they aren't root; // if they are root, then thinking about them is the point. if (node != search_->root_node_) { if (!board.HasMatingMaterial()) { - node->MakeTerminal(GameResult::DRAW); - return; + return certaintyresult = {GameResult::DRAW, CertaintyTrigger::TERMINAL}; } if (history_.Last().GetNoCaptureNoPawnPly() >= 100) { - node->MakeTerminal(GameResult::DRAW); - return; + return certaintyresult = {GameResult::DRAW, CertaintyTrigger::TERMINAL}; } if (history_.Last().GetRepetitions() >= 2) { - node->MakeTerminal(GameResult::DRAW); - return; + return certaintyresult = {GameResult::DRAW, CertaintyTrigger::TERMINAL}; + } + + if ((history_.Last().GetRepetitions() >= 1) && + params_.GetTwoFoldDrawScoring()) { + return certaintyresult = {GameResult::DRAW, CertaintyTrigger::TWO_FOLD}; } // Neither by-position or by-rule termination, but maybe it's a TB position. - if (search_->syzygy_tb_ && board.castlings().no_legal_castle() && + if (!search_->root_syzygy_rank_ && search_->syzygy_tb_ && + board.castlings().no_legal_castle() && history_.Last().GetNoCaptureNoPawnPly() == 0 && (board.ours() | board.theirs()).count() <= search_->syzygy_tb_->max_cardinality()) { @@ -1050,18 +1112,50 @@ void SearchWorker::ExtendNode(Node* node) { if (state != FAIL) { // If the colors seem backwards, check the checkmate check above. if (wdl == WDL_WIN) { - node->MakeTerminal(GameResult::BLACK_WON); + certaintyresult = { GameResult::BLACK_WON, CertaintyTrigger::TB_HIT }; } else if (wdl == WDL_LOSS) { - node->MakeTerminal(GameResult::WHITE_WON); + certaintyresult = { GameResult::WHITE_WON, CertaintyTrigger::TB_HIT }; } else { // Cursed wins and blessed losses count as draws. - node->MakeTerminal(GameResult::DRAW); + certaintyresult = { GameResult::DRAW, CertaintyTrigger::NORMAL }; } search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel); - return; } } } + return certaintyresult; +} +void SearchWorker::ExtendNode(Node* node) { + // Initialize position sequence with pre-move position. + history_.Trim(search_->played_history_.GetLength()); + std::vector to_add; + // Could instead reserve one more than the difference between history_.size() + // and history_.capacity(). + to_add.reserve(60); + Node* cur = node; + while (cur != search_->root_node_) { + Node* prev = cur->GetParent(); + to_add.push_back(prev->GetEdgeToNode(cur)->GetMove()); + cur = prev; + } + for (int i = to_add.size() - 1; i >= 0; i--) { + history_.Append(to_add[i]); + } + + // We don't need the mutex because other threads will see that N=0 and + // N-in-flight=1 and will not touch this node. + const auto& board = history_.Last().GetBoard(); + auto legal_moves = board.GenerateLegalMoves(); + CertaintyResult certaintyresult = + EvalPosition(node, legal_moves, board); + + if (certaintyresult.trigger != CertaintyTrigger::NONE) { + if (certaintyresult.trigger == CertaintyTrigger::TERMINAL) + node->MakeTerminal(certaintyresult.gameresult); + else + node->MakeCertain(certaintyresult); + return; + } // Add legal moves as edges of this node. node->CreateEdges(legal_moves); } @@ -1106,6 +1200,7 @@ void SearchWorker::MaybePrefetchIntoCache() { // TODO(mooskagh) Remove prefetch into cache if node collisions work well. // If there are requests to NN, but the batch is not full, try to prefetch // nodes which are likely useful in future. + // TODO(Videodr0me) Maybe use bounds here to more efficiently select nodes. if (search_->stop_.load(std::memory_order_acquire)) return; if (computation_->GetCacheMisses() > 0 && computation_->GetCacheMisses() < params_.GetMaxPrefetchBatch()) { @@ -1136,8 +1231,8 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) { assert(node); // n = 0 and n_in_flight_ > 0, that means the node is being extended. if (node->GetN() == 0) return 0; - // The node is terminal; don't prefetch it. - if (node->IsTerminal()) return 0; + // The node is certain; don't prefetch it. + if (node->IsCertain()) return 0; // Populate all subnodes and their scores. typedef std::pair ScoredEdge; @@ -1216,8 +1311,8 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process, int idx_in_computation) { Node* node = node_to_process->node; if (!node_to_process->nn_queried) { - // Terminal nodes don't involve the neural NetworkComputation, nor do - // they require any further processing after value retrieval. + // Terminal or certain nodes don't involve the neural NetworkComputation, + // nor do they require any further processing after value retrieval. node_to_process->v = node->GetQ(); node_to_process->d = node->GetD(); return; @@ -1278,16 +1373,74 @@ void SearchWorker::DoBackupUpdateSingleNode( // Backup V value up to a root. After 1 visit, V = Q. float v = node_to_process.v; float d = node_to_process.d; + bool origin_bounded = node->IsBounded(); for (Node* n = node; n != search_->root_node_->GetParent(); n = n->GetParent()) { + // Certainty Propagation: + // If update could affect bounds (origin_bounded), + // check all childs, and update bounds/certainty. + float prev_q = -100.0f; + float prev_d = -100.0f; + if (params_.GetCertaintyPropagation() && n != node && (origin_bounded) && + !n->IsCertain()) { + bool based_on_propagated_tbhit = false; + int lower_bound = -1; + int upper_bound = -1; + for (auto iter : n->Edges()) { + if (iter.IsLBounded() && iter.GetEQ() > lower_bound) + lower_bound = iter.GetEQ(); + if (iter.IsUBounded() && iter.GetEQ() > upper_bound) + upper_bound = iter.GetEQ(); + // Only checking !UBounded so that lower bounded + // edges, also get the correct upper_bound. + if (!iter.IsUBounded()) upper_bound = 1; + if (lower_bound == upper_bound && lower_bound == 1) { + based_on_propagated_tbhit = iter.IsPropagatedTBHit(); + break; + } + based_on_propagated_tbhit |= iter.IsPropagatedTBHit(); + } + // Exact scores are certain and propagate certainty. + // Inexact scores propagate their bounds. + if (lower_bound == upper_bound) { + if (n != search_->root_node_) { + prev_q = n->GetQ(); + prev_d = n->GetD(); + n->MakeCertain(-lower_bound, based_on_propagated_tbhit + ? CertaintyTrigger::TB_HIT + : CertaintyTrigger::NORMAL); + v = (float)-lower_bound; + } + } else { + if (lower_bound > -1) n->UBound(-lower_bound); + if (upper_bound < 1) n->LBound(-upper_bound); + } + } + + // Certainty propagation: reduce error by keeping score in proven bounds. + if (params_.GetCertaintyPropagation() && n->GetParent() && + !n->IsCertain()) { + if (n->GetOwnEdge()->IsUBounded() && v > 0.0f) v = 0.00f; + if (n->GetOwnEdge()->IsLBounded() && v < 0.0f) v = 0.00f; + } + n->FinalizeScoreUpdate(v, d, node_to_process.multivisit); + + // Certainty propagation: adjust Qs along the path as if all visits already + // had propagated the certain result. + if (params_.GetCertaintyPropagation() && (prev_q != -100.0f) && + (prev_q != v) && n->IsCertain()) { + v = v + (v - prev_q) * (n->GetN() - 1); + d = d + (d - prev_d) * (n->GetN() - 1); + } + // Q will be flipped for opponent. v = -v; - // Update the stats. - // Best move. + // Update best move if new N > best N or + // if the node is a certain child of root. if (n->GetParent() == search_->root_node_ && - search_->current_best_edge_.GetN() <= n->GetN()) { + (search_->current_best_edge_.GetN() <= n->GetN() || n->IsCertain())) { search_->current_best_edge_ = search_->GetBestChildNoTemperature(search_->root_node_); } diff --git a/src/mcts/search.h b/src/mcts/search.h index 628d08d54b..cd5778b9f7 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -118,15 +118,19 @@ class Search { void SendUciInfo(); // Requires nodes_mutex_ to be held. // Sets stop to true and notifies watchdog thread. void FireStopInternal(); - void SendMovesStats() const; // Function which runs in a separate thread and watches for time and // uci `stop` command; void WatchdogThread(); // Populates the given list with allowed root moves. - // Returns true if the population came from tablebase. - bool PopulateRootMoveLimit(MoveList* root_moves) const; + // Returns best_rank != 0 if the population came from tablebase. + // WDL and DTZ ranks of +1000 are certain wins, -1000 certain losses, + // 1 is a certain draw. For more info on in-between ranks + // (cursed wins, blessed losses, adjusted by dtz) see syzygy probe code. + // Currently only rank = 1 is used to correct score display when + // moves are root filtered, because kSyzygyFastPlayId sets the rep flag. + int PopulateRootMoveLimit(MoveList* root_moves) const; // Returns verbose information about given node, as vector of strings. std::vector GetVerboseStats(Node* node, @@ -188,6 +192,7 @@ class Search { // Cummulative depth of all paths taken in PickNodetoExtend. uint64_t cum_depth_ GUARDED_BY(nodes_mutex_) = 0; std::atomic tb_hits_{0}; + std::atomic root_syzygy_rank_{0}; BestMoveInfo::Callback best_move_callback_; ThinkingInfo::Callback info_callback_; @@ -248,12 +253,11 @@ class SearchWorker { void UpdateCounters(); private: + struct NodeToProcess { - bool IsExtendable() const { return !is_collision && !node->IsTerminal(); } + bool IsExtendable() const { return !is_collision && !node->IsCertain(); } bool IsCollision() const { return is_collision; } - bool CanEvalOutOfOrder() const { - return is_cache_hit || node->IsTerminal(); - } + bool CanEvalOutOfOrder() const { return is_cache_hit || node->IsCertain(); } // The node to extend. Node* node; @@ -288,6 +292,7 @@ class SearchWorker { }; NodeToProcess PickNodeToExtend(int collision_limit); + CertaintyResult EvalPosition(const Node* node, const MoveList& legal_moves, const ChessBoard& board); void ExtendNode(Node* node); bool AddNodeToComputation(Node* node, bool add_if_cached); int PrefetchIntoCache(Node* node, int budget); diff --git a/src/syzygy/syzygy.cc b/src/syzygy/syzygy.cc index b6ce505a4e..95f1c295ca 100644 --- a/src/syzygy/syzygy.cc +++ b/src/syzygy/syzygy.cc @@ -1623,9 +1623,13 @@ int SyzygyTablebase::probe_dtz(const Position& pos, ProbeState* result) { } // Use the DTZ tables to rank root moves. -// -// A return value false indicates that not all probes were successful. -bool SyzygyTablebase::root_probe(const Position& pos, bool has_repeated, +// A return value 0 indicates that not all probes were successful. +// Otherwise best rank is returned: +// 1 draw, 1000 win, -1000 loss. +// If rep flag is set for wins: 1000 - (dtz + cnt50). +// If 50 draw in sight for losses: -1000 + (-dtz + cnt50). + +int SyzygyTablebase::root_probe(const Position& pos, bool has_repeated, std::vector* safe_moves) { ProbeState result; auto root_moves = pos.GetBoard().GenerateLegalMoves(); @@ -1655,14 +1659,14 @@ bool SyzygyTablebase::root_probe(const Position& pos, bool has_repeated, next_pos.GetBoard().GenerateLegalMoves().size() == 0) { dtz = 1; } - if (result == FAIL) return false; + if (result == FAIL) return 0; // Better moves are ranked higher. Certain wins are ranked equally. // Losing moves are ranked equally unless a 50-move draw is in sight. int r = dtz > 0 ? (dtz + cnt50 <= 99 && !rep ? 1000 : 1000 - (dtz + cnt50)) : dtz < 0 ? (-dtz * 2 + cnt50 < 100 ? -1000 : -1000 + (-dtz + cnt50)) - : 0; + : 1; if (r > best_rank) best_rank = r; ranks.push_back(r); } @@ -1674,16 +1678,18 @@ bool SyzygyTablebase::root_probe(const Position& pos, bool has_repeated, } counter++; } - return true; + return best_rank; } // Use the WDL tables to rank root moves. // This is a fallback for the case that some or all DTZ tables are missing. -// -// A return value false indicates that not all probes were successful. -bool SyzygyTablebase::root_probe_wdl(const Position& pos, +// A return value 0 indicates that not all probes were successful. +// Otherwise best rank is returned: +// -1000 loss, -899 blessed loss, 1 draw, 899 cursed win and 1000 win. + +int SyzygyTablebase::root_probe_wdl(const Position& pos, std::vector* safe_moves) { - static const int WDL_to_rank[] = {-1000, -899, 0, 899, 1000}; + static const int WDL_to_rank[] = {-1000, -899, 1, 899, 1000}; auto root_moves = pos.GetBoard().GenerateLegalMoves(); ProbeState result; std::vector ranks; @@ -1693,7 +1699,7 @@ bool SyzygyTablebase::root_probe_wdl(const Position& pos, for (auto& m : root_moves) { Position nextPos = Position(pos, m); WDLScore wdl = static_cast(-probe_wdl(nextPos, &result)); - if (result == FAIL) return false; + if (result == FAIL) return 0; ranks.push_back(WDL_to_rank[wdl + 2]); if (ranks.back() > best_rank) best_rank = ranks.back(); } @@ -1705,6 +1711,6 @@ bool SyzygyTablebase::root_probe_wdl(const Position& pos, } counter++; } - return true; + return best_rank; } } // namespace lczero diff --git a/src/syzygy/syzygy.h b/src/syzygy/syzygy.h index 521c0b46d2..25234eddd0 100644 --- a/src/syzygy/syzygy.h +++ b/src/syzygy/syzygy.h @@ -87,16 +87,22 @@ class SyzygyTablebase { // has_repeated should be whether there are any repeats since last 50 move // counter reset. // Thread safe. - // Returns false if the position is not in the tablebase. - // Safe moves are added to the safe_moves output paramater. - bool root_probe(const Position& pos, bool has_repeated, + // A return value 0 indicates that not all probes were successful. + // Otherwise best rank is returned: + // 1 draw, 1000 win, -1000 loss. + // If rep flag is set for wins 1000 - (dtz + cnt50) win. + // If 50 draw in sight for losses: -1000 + (-dtz + cnt50). + // Safe moves are added to the safe_moves output parameter. + int root_probe(const Position& pos, bool has_repeated, std::vector* safe_moves); // Probes WDL tables to determine which moves might be on the optimal play // path. If 50 move ply counter is non-zero some (or maybe even all) of the // returned safe moves in a 'winning' position, may actually be draws. - // Returns false if the position is not in the tablebase. - // Safe moves are added to the safe_moves output paramater. - bool root_probe_wdl(const Position& pos, std::vector* safe_moves); + // A return value 0 indicates that not all probes were successful. + // Otherwise best rank is returned: + // -1000 loss, -899 blessed loss, 1 draw, 899 cursed win and 1000 win. + // Safe moves are added to the safe_moves output parameter. + int root_probe_wdl(const Position& pos, std::vector* safe_moves); private: template