Skip to content

Commit

Permalink
Make ValueHead a separate search algorithm.
Browse files Browse the repository at this point in the history
  • Loading branch information
mooskagh committed Feb 7, 2025
1 parent 49d9f12 commit 3da33ef
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 85 deletions.
14 changes: 7 additions & 7 deletions src/chess/callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,34 @@ struct ThinkingInfo {
// Hash fullness * 1000
int hashfull = -1;
// Moves to mate.
std::optional<int> mate;
std::optional<int> mate = std::nullopt;
// Win in centipawns.
std::optional<int> score;
std::optional<int> score = std::nullopt;
// Win/Draw/Lose probability * 1000.
struct WDL {
int w;
int d;
int l;
};
std::optional<WDL> wdl;
std::optional<WDL> wdl = std::nullopt;
// Number of successful TB probes (not the same as playouts ending in TB hit).
int tb_hits = -1;
// Best line found. Moves are from perspective of white player.
std::vector<Move> pv;
std::vector<Move> pv = {};
// Multipv index.
int multipv = -1;
// Freeform comment.
std::string comment;
std::string comment = "";

// Those are extensions and not really UCI protocol.
// 1 if it's "player1", 2 if it's "player2"
int player = -1;
// Index of the game in the tournament (0-based).
int game_id = -1;
// The color of the player, if known.
std::optional<bool> is_black;
std::optional<bool> is_black = std::nullopt;
// Moves left
std::optional<int> moves_left;
std::optional<int> moves_left = std::nullopt;
};

// Is sent when a single game is finished.
Expand Down
2 changes: 2 additions & 0 deletions src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class PositionHistory {
PositionHistory() = default;
PositionHistory(const PositionHistory& other) = default;
PositionHistory(PositionHistory&& other) = default;
PositionHistory(std::span<const Position> positions)
: positions_(positions.begin(), positions.end()) {}

PositionHistory& operator=(const PositionHistory& other) = default;
PositionHistory& operator=(PositionHistory&& other) = default;
Expand Down
78 changes: 0 additions & 78 deletions src/engine_classic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ const OptionId kStrictUciTiming{"strict-uci-timing", "StrictTiming",
"The UCI host compensates for lag, waits for "
"the 'readyok' reply before sending 'go' and "
"only then starts timing."};
// Option id for the value-only mode. PopulateOptions registers it as a
// BoolOption defaulting to false, and Go() dispatches to ValueOnlyGo() when it
// is set.
const OptionId kValueOnly{
"value-only", "ValueOnly",
"In value only mode all search parameters are ignored and the position is "
"evaluated by getting the valuation of every child position and choosing "
"the worst for the opponent."};
// Option id for a button (registered as a hidden ButtonOption in
// PopulateOptions). Go() tests and resets it, trimming the tree at the current
// head when it was pressed. It has an empty long-flag name.
const OptionId kClearTree{"", "ClearTree",
"Clear the tree before the next search."};

Expand Down Expand Up @@ -123,7 +118,6 @@ void EngineClassic::PopulateOptions(OptionsParser* options) {
options->Add<BoolOption>(kStrictUciTiming) = false;
options->HideOption(kStrictUciTiming);

options->Add<BoolOption>(kValueOnly) = false;
options->Add<ButtonOption>(kClearTree);
options->HideOption(kClearTree);
}
Expand Down Expand Up @@ -265,74 +259,6 @@ class PonderResponseTransformer : public TransformingUciResponder {
std::string ponder_move_;
};

// Chooses a move using only the network's value output, ignoring all search
// parameters.
//
// Every legal move from the current head is played on a scratch copy of the
// position history. Non-terminal children are encoded and evaluated by the
// network in batches; terminal children are scored directly. The move whose
// child is worst for the opponent is reported through `responder`, preceded by
// a single depth-1 thinking info record.
//
//   tree      - supplies the current position history and head node; edges
//               for all legal moves are created on the head node.
//   network   - supplies the input format, the default minibatch size and the
//               computations used for evaluation.
//   options   - read for the minibatch size (0 means "use network default").
//   responder - receives the thinking info and the best move.
void ValueOnlyGo(classic::NodeTree* tree, Network* network,
                 const OptionsDict& options,
                 std::unique_ptr<UciResponder> responder) {
  const auto input_format = network->GetCapabilities().input_format;

  const auto& board = tree->GetPositionHistory().Last().GetBoard();
  auto legal_moves = board.GenerateLegalMoves();
  tree->GetCurrentHead()->CreateEdges(legal_moves);
  PositionHistory history = tree->GetPositionHistory();

  // Encode every child position that still needs a network evaluation.
  std::vector<InputPlanes> planes;
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    if (history.ComputeGameResult() == GameResult::UNDECIDED) {
      planes.emplace_back(EncodePositionForNN(
          input_format, history, 8, FillEmptyHistory::FEN_ONLY, nullptr));
    }
    history.Pop();
  }

  std::vector<float> comp_q;
  comp_q.reserve(planes.size());
  int batch_size = options.Get<int>(classic::SearchParams::kMiniBatchSizeId);
  if (batch_size == 0) batch_size = network->GetMiniBatchSize();

  for (size_t i = 0; i < planes.size(); i += batch_size) {
    // The final batch may be shorter than batch_size; only add and read back
    // as many samples as actually exist.
    const size_t remaining = planes.size() - i;
    const int count = remaining < static_cast<size_t>(batch_size)
                          ? static_cast<int>(remaining)
                          : batch_size;

    auto comp = network->NewComputation();
    for (int j = 0; j < count; j++) {
      comp->AddInput(std::move(planes[i + j]));
    }
    comp->ComputeBlocking();

    for (int j = 0; j < count; j++) comp_q.push_back(comp->GetQVal(j));
  }

  // Walk the edges in the same order as above so that comp_idx lines up with
  // the order in which the non-terminal children were encoded.
  Move best;
  int comp_idx = 0;
  float max_q = std::numeric_limits<float>::lowest();
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    auto result = history.ComputeGameResult();
    float q = -1;
    if (result == GameResult::UNDECIDED) {
      // NN eval is from the side-to-move perspective - so if it's good, it's
      // bad for us.
      q = -comp_q[comp_idx];
      comp_idx++;
    } else if (result == GameResult::DRAW) {
      q = 0;
    } else {
      // A legal move to a non-drawn terminal without tablebases must be a
      // win.
      q = 1;
    }
    // `>=` means that among equally valued moves the last one wins.
    if (q >= max_q) {
      max_q = q;
      best = edge.GetMove(tree->GetPositionHistory().IsBlackToMove());
    }
    history.Pop();
  }

  std::vector<ThinkingInfo> infos;
  ThinkingInfo thinking;
  thinking.depth = 1;
  infos.push_back(thinking);
  responder->OutputThinkingInfo(&infos);
  BestMoveInfo info(best);
  responder->OutputBestMove(&info);
}

} // namespace

void EngineClassic::Go(const GoParams& params) {
Expand Down Expand Up @@ -374,10 +300,6 @@ void EngineClassic::Go(const GoParams& params) {
// Strip movesleft information from the response.
responder = std::make_unique<MovesLeftResponseFilter>(std::move(responder));
}
if (options_.Get<bool>(kValueOnly)) {
ValueOnlyGo(tree_.get(), network_.get(), options_, std::move(responder));
return;
}

if (options_.Get<Button>(kClearTree).TestAndReset()) {
tree_->TrimTreeAtHead();
Expand Down
77 changes: 77 additions & 0 deletions src/search/instamove/instamove.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "search/instamove/instamove.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "chess/gamestate.h"
Expand Down Expand Up @@ -68,7 +69,83 @@ class PolicyHeadFactory : public SearchFactory {
}
};

// Instant-move search that picks a move from value-head evaluations alone.
//
// Each legal move is played on a scratch history. Undecided child positions
// are sent to the backend in one computation; terminal children get fixed
// values. The child with the lowest q (i.e. the one that is worst for the
// opponent, who is to move there) is selected.
class ValueHeadSearch : public InstamoveSearch {
 public:
  ValueHeadSearch(const SearchContext& context, const GameState& game_state)
      : InstamoveSearch(context), game_state_(game_state) {}

  // Evaluates all children, emits one depth-1 thinking info record for the
  // chosen move and returns that move. Returns a default-constructed Move when
  // the position has no legal moves.
  Move GetBestMove() final {
    PositionHistory history(game_state_.GetPositions());
    const ChessBoard& board = history.Last().GetBoard();
    const std::vector<Move> legal_moves = board.GenerateLegalMoves();
    // Without this guard min_element() would return end() and results[0]
    // would be indexed on an empty vector.
    if (legal_moves.empty()) return Move();

    std::unique_ptr<BackendComputation> computation =
        backend()->CreateComputation();
    std::vector<EvalResult> results(legal_moves.size());

    for (size_t i = 0; i < legal_moves.size(); i++) {
      history.Append(legal_moves[i]);
      switch (history.ComputeGameResult()) {
        case GameResult::UNDECIDED:
          // The backend writes q and d straight into results[i].
          computation->AddInput(
              EvalPosition{history.GetPositions(), {}},
              EvalResultPtr{.q = &results[i].q, .d = &results[i].d});
          break;
        case GameResult::DRAW:
          results[i].q = 0;
          results[i].d = 1;
          break;
        default:
          // A legal move to a non-drawn terminal without tablebases must be a
          // win for us, i.e. a loss (-1) for the side to move in the child.
          results[i].q = -1;
          results[i].d = 0;
      }
      history.Pop();
    }

    computation->ComputeBlocking();

    // Stored q values are from the opponent's point of view, so the smallest
    // one marks our best move.
    const size_t best_idx =
        std::min_element(results.begin(), results.end(),
                         [](const EvalResult& a, const EvalResult& b) {
                           return a.q < b.q;
                         }) -
        results.begin();

    // Flip the sign so that the reported score and W/L are from the point of
    // view of the side making the move; the draw probability is symmetric.
    const auto q = -results[best_idx].q;
    const auto d = results[best_idx].d;

    std::vector<ThinkingInfo> infos = {{
        .depth = 1,
        .seldepth = 1,
        .nodes = static_cast<int64_t>(legal_moves.size()),
        .score = static_cast<int>(90 * std::tan(1.5637541897 * q)),
        .wdl = ThinkingInfo::WDL{
            static_cast<int>(std::round(500 * (1 + q - d))),
            static_cast<int>(std::round(1000 * d)),
            static_cast<int>(std::round(500 * (1 - q - d)))},
    }};
    uci_responder()->OutputThinkingInfo(&infos);
    return legal_moves[best_idx];
  }

 private:
  // Copy of the game state the search was created for.
  const GameState game_state_;
};

class ValueHeadFactory : public SearchFactory {
std::string_view GetName() const override { return "valuehead"; }
std::unique_ptr<SearchEnvironment> CreateEnvironment(
UciResponder* responder, const OptionsDict* options) const override {
return std::make_unique<InstamoveEnvironment<ValueHeadSearch>>(responder,
options);
}
};

// Register both instamove search factories defined in this file.
// NOTE(review): REGISTER_SEARCH is defined elsewhere; presumably it makes each
// factory selectable by the name returned from GetName() - confirm.
REGISTER_SEARCH(PolicyHeadFactory);
REGISTER_SEARCH(ValueHeadFactory);

} // namespace instamove
} // namespace lczero

0 comments on commit 3da33ef

Please sign in to comment.