Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a "valuehead" search algorithm. #2121

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/chess/callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,34 @@ struct ThinkingInfo {
// Hash fullness * 1000
int hashfull = -1;
// Moves to mate.
std::optional<int> mate;
std::optional<int> mate = std::nullopt;
// Win in centipawns.
std::optional<int> score;
std::optional<int> score = std::nullopt;
// Win/Draw/Lose probability * 1000.
struct WDL {
int w;
int d;
int l;
};
std::optional<WDL> wdl;
std::optional<WDL> wdl = std::nullopt;
// Number of successful TB probes (not the same as playouts ending in TB hit).
int tb_hits = -1;
// Best line found. Moves are from perspective of white player.
std::vector<Move> pv;
std::vector<Move> pv = {};
// Multipv index.
int multipv = -1;
// Freeform comment.
std::string comment;
std::string comment = "";

// These are extensions and not part of the UCI protocol.
// 1 if it's "player1", 2 if it's "player2"
int player = -1;
// Index of the game in the tournament (0-based).
int game_id = -1;
// The color of the player, if known.
std::optional<bool> is_black;
std::optional<bool> is_black = std::nullopt;
// Moves left
std::optional<int> moves_left;
std::optional<int> moves_left = std::nullopt;
};

// Is sent when a single game is finished.
Expand Down
2 changes: 2 additions & 0 deletions src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class PositionHistory {
PositionHistory() = default;
PositionHistory(const PositionHistory& other) = default;
PositionHistory(PositionHistory&& other) = default;
// Constructs a history from an existing sequence of positions: positions_ is
// initialized from the [begin, end) range of the given span.
PositionHistory(std::span<const Position> positions)
: positions_(positions.begin(), positions.end()) {}

PositionHistory& operator=(const PositionHistory& other) = default;
PositionHistory& operator=(PositionHistory&& other) = default;
Expand Down
78 changes: 0 additions & 78 deletions src/engine_classic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ const OptionId kStrictUciTiming{"strict-uci-timing", "StrictTiming",
"The UCI host compensates for lag, waits for "
"the 'readyok' reply before sending 'go' and "
"only then starts timing."};
const OptionId kValueOnly{
"value-only", "ValueOnly",
"In value only mode all search parameters are ignored and the position is "
"evaluated by getting the valuation of every child position and choosing "
"the worst for the opponent."};
const OptionId kClearTree{"", "ClearTree",
"Clear the tree before the next search."};

Expand Down Expand Up @@ -123,7 +118,6 @@ void EngineClassic::PopulateOptions(OptionsParser* options) {
options->Add<BoolOption>(kStrictUciTiming) = false;
options->HideOption(kStrictUciTiming);

options->Add<BoolOption>(kValueOnly) = false;
options->Add<ButtonOption>(kClearTree);
options->HideOption(kClearTree);
}
Expand Down Expand Up @@ -265,74 +259,6 @@ class PonderResponseTransformer : public TransformingUciResponder {
std::string ponder_move_;
};

// Chooses a move using only the network's value output.
//
// Every legal move from the current head is played on a scratch copy of the
// position history. Non-terminal child positions are evaluated by the network
// in mini-batches; terminal children are scored directly. The move whose
// child position is worst for the opponent is reported as the best move.
//
// tree:      supplies the current position history and head node; edges for
//            all legal moves are created on the head node.
// network:   used for the input format, the default mini-batch size and the
//            evaluations.
// options:   read for classic::SearchParams::kMiniBatchSizeId; a value of 0
//            falls back to network->GetMiniBatchSize().
// responder: receives a single depth-1 ThinkingInfo, then the best move.
void ValueOnlyGo(classic::NodeTree* tree, Network* network,
                 const OptionsDict& options,
                 std::unique_ptr<UciResponder> responder) {
  auto input_format = network->GetCapabilities().input_format;

  const auto& board = tree->GetPositionHistory().Last().GetBoard();
  auto legal_moves = board.GenerateLegalMoves();
  tree->GetCurrentHead()->CreateEdges(legal_moves);
  PositionHistory history = tree->GetPositionHistory();

  // Encode only the children that still need a network evaluation; terminal
  // children are scored without the network in the second pass below.
  std::vector<InputPlanes> planes;
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    if (history.ComputeGameResult() == GameResult::UNDECIDED) {
      planes.emplace_back(EncodePositionForNN(
          input_format, history, 8, FillEmptyHistory::FEN_ONLY, nullptr));
    }
    history.Pop();
  }

  std::vector<float> comp_q;
  int batch_size = options.Get<int>(classic::SearchParams::kMiniBatchSizeId);
  if (batch_size == 0) batch_size = network->GetMiniBatchSize();

  for (size_t i = 0; i < planes.size(); i += batch_size) {
    auto comp = network->NewComputation();
    // The last batch may be shorter than batch_size, so remember how many
    // inputs were actually added and read back exactly that many results.
    int batch_count = 0;
    while (batch_count < batch_size && i + batch_count < planes.size()) {
      comp->AddInput(std::move(planes[i + batch_count]));
      ++batch_count;
    }
    comp->ComputeBlocking();

    for (int j = 0; j < batch_count; j++) comp_q.push_back(comp->GetQVal(j));
  }

  Move best;
  int comp_idx = 0;
  float max_q = std::numeric_limits<float>::lowest();
  // Second pass over the edges in the same order as the first one, so
  // comp_idx walks comp_q in step with the non-terminal children.
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    auto result = history.ComputeGameResult();
    float q = -1;
    if (result == GameResult::UNDECIDED) {
      // NN eval is from the side-to-move perspective, i.e. the opponent's
      // after our move - so if it is good for them, it is bad for us.
      q = -comp_q[comp_idx];
      comp_idx++;
    } else if (result == GameResult::DRAW) {
      q = 0;
    } else {
      // A legal move to a non-drawn terminal without tablebases must be a
      // win.
      q = 1;
    }
    if (q >= max_q) {
      max_q = q;
      best = edge.GetMove(tree->GetPositionHistory().IsBlackToMove());
    }
    history.Pop();
  }
  std::vector<ThinkingInfo> infos;
  ThinkingInfo thinking;
  thinking.depth = 1;
  infos.push_back(thinking);
  responder->OutputThinkingInfo(&infos);
  BestMoveInfo info(best);
  responder->OutputBestMove(&info);
}

} // namespace

void EngineClassic::Go(const GoParams& params) {
Expand Down Expand Up @@ -374,10 +300,6 @@ void EngineClassic::Go(const GoParams& params) {
// Strip movesleft information from the response.
responder = std::make_unique<MovesLeftResponseFilter>(std::move(responder));
}
if (options_.Get<bool>(kValueOnly)) {
ValueOnlyGo(tree_.get(), network_.get(), options_, std::move(responder));
return;
}

if (options_.Get<Button>(kClearTree).TestAndReset()) {
tree_->TrimTreeAtHead();
Expand Down
2 changes: 1 addition & 1 deletion src/neural/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct EvalResultPtr {
float* q = nullptr;
float* d = nullptr;
float* m = nullptr;
std::span<float> p;
std::span<float> p = {};
};

struct EvalPosition {
Expand Down
79 changes: 79 additions & 0 deletions src/search/instamove/instamove.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "search/instamove/instamove.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "chess/gamestate.h"
Expand Down Expand Up @@ -68,7 +69,85 @@ class PolicyHeadFactory : public SearchFactory {
}
};

// Instamove search that chooses a move from the network's value output alone.
//
// Each legal move is played on a scratch history. Non-terminal children are
// queued for backend evaluation; terminal children are scored directly. The
// move whose child position has the lowest q (i.e. is worst for the opponent,
// who is to move there) is returned.
class ValueHeadSearch : public InstamoveSearch {
 public:
  ValueHeadSearch(const SearchContext& context, const GameState& game_state)
      : InstamoveSearch(context), game_state_(game_state) {}

  // Evaluates all children of the current position, emits one ThinkingInfo
  // line describing the chosen move, and returns that move. Returns a
  // default-constructed Move when the position has no legal moves.
  Move GetBestMove() final {
    PositionHistory history(game_state_.GetPositions());
    const ChessBoard& board = history.Last().GetBoard();
    const std::vector<Move> legal_moves = board.GenerateLegalMoves();
    // Without this guard min_element() below would return end() and the
    // subsequent results[0] / legal_moves[0] accesses would be out of bounds.
    if (legal_moves.empty()) return Move();

    std::unique_ptr<BackendComputation> computation =
        backend()->CreateComputation();
    // Sized once up front: the computation is handed pointers into this
    // vector, so it must not reallocate afterwards.
    std::vector<EvalResult> results(legal_moves.size());

    for (size_t i = 0; i < legal_moves.size(); i++) {
      history.Append(legal_moves[i]);
      switch (history.ComputeGameResult()) {
        case GameResult::UNDECIDED:
          computation->AddInput(
              EvalPosition{history.GetPositions(), {}},
              EvalResultPtr{.q = &results[i].q, .d = &results[i].d});
          break;
        case GameResult::DRAW:
          results[i].q = 0;
          results[i].d = 1;
          break;
        default:
          // A legal move to a non-drawn terminal without tablebases must be a
          // win for the mover, i.e. a loss (q = -1) for the side to move in
          // the resulting position, which is the perspective stored here.
          results[i].q = -1;
          results[i].d = 0;
      }
      history.Pop();
    }

    computation->ComputeBlocking();

    // Lowest q for the opponent == best move for us.
    const size_t best_idx =
        std::min_element(results.begin(), results.end(),
                         [](const EvalResult& a, const EvalResult& b) {
                           return a.q < b.q;
                         }) -
        results.begin();

    // results[] holds values from the opponent's point of view (see the
    // terminal cases above), so flip q before reporting score and WDL from
    // the point of view of the side choosing the move. The draw probability
    // is the same for both sides.
    const float q = -results[best_idx].q;
    const float d = results[best_idx].d;

    std::vector<ThinkingInfo> infos = {{
        .depth = 1,
        .seldepth = 1,
        .nodes = static_cast<int64_t>(legal_moves.size()),
        // Maps q in [-1, 1] onto a centipawn-style score; truncated to int.
        .score = static_cast<int>(90 * std::tan(1.5637541897 * q)),
        // Win/draw/loss in permille, derived from q = w - l and w + d + l = 1.
        .wdl =
            ThinkingInfo::WDL{
                static_cast<int>(std::round(500 * (1 + q - d))),
                static_cast<int>(std::round(1000 * d)),
                static_cast<int>(std::round(500 * (1 - q - d)))},
    }};
    uci_responder()->OutputThinkingInfo(&infos);
    Move best_move = legal_moves[best_idx];
    // Every Append() above was matched by a Pop(), so history is back at the
    // root position here.
    // NOTE(review): presumably generated moves are in the side-to-move frame
    // and need mirroring when black is to move - confirm against Move docs.
    if (history.IsBlackToMove()) best_move.Mirror();
    return best_move;
  }

 private:
  const GameState game_state_;
};

class ValueHeadFactory : public SearchFactory {
std::string_view GetName() const override { return "valuehead"; }
std::unique_ptr<SearchEnvironment> CreateEnvironment(
UciResponder* responder, const OptionsDict* options) const override {
return std::make_unique<InstamoveEnvironment<ValueHeadSearch>>(responder,
options);
}
};

REGISTER_SEARCH(PolicyHeadFactory);
REGISTER_SEARCH(ValueHeadFactory);

} // namespace instamove
} // namespace lczero
6 changes: 3 additions & 3 deletions src/search/register.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class SearchManager {
std::vector<std::unique_ptr<SearchFactory>> algorithms_;
};

#define REGISTER_SEARCH(alg) \
namespace { \
static SearchManager::Register reg3b50Y##algorithm(std::make_unique<alg>()); \
// Registers search factory type `alg` at static-initialization time by
// defining, inside an anonymous namespace, a SearchManager::Register object
// constructed from std::make_unique<alg>(). The factory type name is pasted
// into the variable name (reg3b50Y_<alg>) so that several REGISTER_SEARCH
// invocations in one translation unit declare distinct variables.
#define REGISTER_SEARCH(alg) \
namespace { \
static SearchManager::Register reg3b50Y_##alg(std::make_unique<alg>()); \
}

} // namespace lczero