Skip to content

Commit 3da33ef

Browse files
committed
Make ValueHead a separate search algorithm.
1 parent 49d9f12 commit 3da33ef

File tree

4 files changed

+86
-85
lines changed

4 files changed

+86
-85
lines changed

src/chess/callbacks.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,34 +68,34 @@ struct ThinkingInfo {
6868
// Hash fullness * 1000
6969
int hashfull = -1;
7070
// Moves to mate.
71-
std::optional<int> mate;
71+
std::optional<int> mate = std::nullopt;
7272
// Win in centipawns.
73-
std::optional<int> score;
73+
std::optional<int> score = std::nullopt;
7474
// Win/Draw/Lose probability * 1000.
7575
struct WDL {
7676
int w;
7777
int d;
7878
int l;
7979
};
80-
std::optional<WDL> wdl;
80+
std::optional<WDL> wdl = std::nullopt;
8181
// Number of successful TB probes (not the same as playouts ending in TB hit).
8282
int tb_hits = -1;
8383
// Best line found. Moves are from perspective of white player.
84-
std::vector<Move> pv;
84+
std::vector<Move> pv = {};
8585
// Multipv index.
8686
int multipv = -1;
8787
// Freeform comment.
88-
std::string comment;
88+
std::string comment = "";
8989

9090
// The fields below are extensions and not part of the UCI protocol.
9191
// 1 if it's "player1", 2 if it's "player2"
9292
int player = -1;
9393
// Index of the game in the tournament (0-based).
9494
int game_id = -1;
9595
// The color of the player, if known.
96-
std::optional<bool> is_black;
96+
std::optional<bool> is_black = std::nullopt;
9797
// Moves left
98-
std::optional<int> moves_left;
98+
std::optional<int> moves_left = std::nullopt;
9999
};
100100

101101
// Is sent when a single game is finished.

src/chess/position.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class PositionHistory {
9898
PositionHistory() = default;
9999
PositionHistory(const PositionHistory& other) = default;
100100
PositionHistory(PositionHistory&& other) = default;
101+
PositionHistory(std::span<const Position> positions)
102+
: positions_(positions.begin(), positions.end()) {}
101103

102104
PositionHistory& operator=(const PositionHistory& other) = default;
103105
PositionHistory& operator=(PositionHistory&& other) = default;

src/engine_classic.cc

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,6 @@ const OptionId kStrictUciTiming{"strict-uci-timing", "StrictTiming",
6161
"The UCI host compensates for lag, waits for "
6262
"the 'readyok' reply before sending 'go' and "
6363
"only then starts timing."};
64-
const OptionId kValueOnly{
65-
"value-only", "ValueOnly",
66-
"In value only mode all search parameters are ignored and the position is "
67-
"evaluated by getting the valuation of every child position and choosing "
68-
"the worst for the opponent."};
6964
const OptionId kClearTree{"", "ClearTree",
7065
"Clear the tree before the next search."};
7166

@@ -123,7 +118,6 @@ void EngineClassic::PopulateOptions(OptionsParser* options) {
123118
options->Add<BoolOption>(kStrictUciTiming) = false;
124119
options->HideOption(kStrictUciTiming);
125120

126-
options->Add<BoolOption>(kValueOnly) = false;
127121
options->Add<ButtonOption>(kClearTree);
128122
options->HideOption(kClearTree);
129123
}
@@ -265,74 +259,6 @@ class PonderResponseTransformer : public TransformingUciResponder {
265259
std::string ponder_move_;
266260
};
267261

268-
// Picks a move using the value head only: every legal child of the current
// head is evaluated by the network and the move whose resulting position is
// worst for the opponent is reported as the best move.
//
// tree       Game tree; edges for all legal moves are created on its current
//            head.
// network    Network used to evaluate the non-terminal child positions.
// options    Read for classic::SearchParams::kMiniBatchSizeId; a value of 0
//            falls back to network->GetMiniBatchSize().
// responder  Receives a single ThinkingInfo with depth 1, followed by the best
//            move.
void ValueOnlyGo(classic::NodeTree* tree, Network* network,
                 const OptionsDict& options,
                 std::unique_ptr<UciResponder> responder) {
  auto input_format = network->GetCapabilities().input_format;

  const auto& board = tree->GetPositionHistory().Last().GetBoard();
  auto legal_moves = board.GenerateLegalMoves();
  tree->GetCurrentHead()->CreateEdges(legal_moves);
  PositionHistory history = tree->GetPositionHistory();

  // Encode every child position that is not already decided; terminal
  // children are scored directly in the second pass below.
  std::vector<InputPlanes> planes;
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    if (history.ComputeGameResult() == GameResult::UNDECIDED) {
      planes.emplace_back(EncodePositionForNN(
          input_format, history, 8, FillEmptyHistory::FEN_ONLY, nullptr));
    }
    history.Pop();
  }

  std::vector<float> comp_q;
  comp_q.reserve(planes.size());
  int batch_size = options.Get<int>(classic::SearchParams::kMiniBatchSizeId);
  if (batch_size == 0) batch_size = network->GetMiniBatchSize();

  for (size_t i = 0; i < planes.size(); i += batch_size) {
    auto comp = network->NewComputation();
    // The final batch may be partial. Only as many results are read back as
    // inputs were actually added, so GetQVal() is never asked for a sample
    // that does not exist.
    const size_t remaining = planes.size() - i;
    const int count = remaining < static_cast<size_t>(batch_size)
                          ? static_cast<int>(remaining)
                          : batch_size;
    for (int j = 0; j < count; j++) {
      comp->AddInput(std::move(planes[i + j]));
    }
    comp->ComputeBlocking();

    for (int j = 0; j < count; j++) comp_q.push_back(comp->GetQVal(j));
  }

  // Second pass over the same edges, in the same order, so that comp_idx
  // walks comp_q in step with the undecided children encoded above.
  // `best` stays default-constructed if there are no edges at all.
  Move best;
  int comp_idx = 0;
  float max_q = std::numeric_limits<float>::lowest();
  for (auto edge : tree->GetCurrentHead()->Edges()) {
    history.Append(edge.GetMove());
    auto result = history.ComputeGameResult();
    float q = -1;
    if (result == GameResult::UNDECIDED) {
      // NN eval is from the side-to-move perspective - so if it's good, it's
      // bad for us.
      q = -comp_q[comp_idx];
      comp_idx++;
    } else if (result == GameResult::DRAW) {
      q = 0;
    } else {
      // A legal move to a non-drawn terminal without tablebases must be a
      // win.
      q = 1;
    }
    // `>=` means that among equally good moves the last one wins.
    if (q >= max_q) {
      max_q = q;
      best = edge.GetMove(tree->GetPositionHistory().IsBlackToMove());
    }
    history.Pop();
  }
  std::vector<ThinkingInfo> infos;
  ThinkingInfo thinking;
  thinking.depth = 1;
  infos.push_back(thinking);
  responder->OutputThinkingInfo(&infos);
  BestMoveInfo info(best);
  responder->OutputBestMove(&info);
}
335-
336262
} // namespace
337263

338264
void EngineClassic::Go(const GoParams& params) {
@@ -374,10 +300,6 @@ void EngineClassic::Go(const GoParams& params) {
374300
// Strip movesleft information from the response.
375301
responder = std::make_unique<MovesLeftResponseFilter>(std::move(responder));
376302
}
377-
if (options_.Get<bool>(kValueOnly)) {
378-
ValueOnlyGo(tree_.get(), network_.get(), options_, std::move(responder));
379-
return;
380-
}
381303

382304
if (options_.Get<Button>(kClearTree).TestAndReset()) {
383305
tree_->TrimTreeAtHead();

src/search/instamove/instamove.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "search/instamove/instamove.h"
2929

3030
#include <algorithm>
31+
#include <cmath>
3132
#include <vector>
3233

3334
#include "chess/gamestate.h"
@@ -68,7 +69,83 @@ class PolicyHeadFactory : public SearchFactory {
6869
}
6970
};
7071

72+
class ValueHeadSearch : public InstamoveSearch {
73+
public:
74+
ValueHeadSearch(const SearchContext& context, const GameState& game_state)
75+
: InstamoveSearch(context), game_state_(game_state) {}
76+
77+
Move GetBestMove() final {
78+
std::unique_ptr<BackendComputation> computation =
79+
backend()->CreateComputation();
80+
81+
PositionHistory history(game_state_.GetPositions());
82+
const ChessBoard& board = history.Last().GetBoard();
83+
const std::vector<Move> legal_moves = board.GenerateLegalMoves();
84+
std::vector<EvalResult> results(legal_moves.size());
85+
86+
for (size_t i = 0; i < legal_moves.size(); i++) {
87+
Move move = legal_moves[i];
88+
history.Append(move);
89+
switch (history.ComputeGameResult()) {
90+
case GameResult::UNDECIDED:
91+
computation->AddInput(
92+
EvalPosition{history.GetPositions(), {}},
93+
EvalResultPtr{.q = &results[i].q, .d = &results[i].d});
94+
break;
95+
case GameResult::DRAW:
96+
results[i].q = 0;
97+
results[i].d = 1;
98+
break;
99+
default:
100+
// A legal move to a non-drawn terminal without tablebases must be a
101+
// win.
102+
results[i].q = -1;
103+
results[i].d = 0;
104+
}
105+
history.Pop();
106+
}
107+
108+
computation->ComputeBlocking();
109+
110+
const size_t best_idx =
111+
std::min_element(results.begin(), results.end(),
112+
[](const EvalResult& a, const EvalResult& b) {
113+
return a.q < b.q;
114+
}) -
115+
results.begin();
116+
117+
std::vector<ThinkingInfo> infos = {{
118+
.depth = 1,
119+
.seldepth = 1,
120+
.nodes = static_cast<int64_t>(legal_moves.size()),
121+
.score = 90 * std::tan(1.5637541897 * results[best_idx].q),
122+
.wdl =
123+
ThinkingInfo::WDL{
124+
static_cast<int>(std::round(
125+
500 * (1 + results[best_idx].q - results[best_idx].d))),
126+
static_cast<int>(std::round(1000 * results[best_idx].d)),
127+
static_cast<int>(std::round(
128+
500 * (1 - results[best_idx].q - results[best_idx].d)))},
129+
}};
130+
uci_responder()->OutputThinkingInfo(&infos);
131+
return legal_moves[best_idx];
132+
}
133+
134+
private:
135+
const GameState game_state_;
136+
};
137+
138+
class ValueHeadFactory : public SearchFactory {
139+
std::string_view GetName() const override { return "valuehead"; }
140+
std::unique_ptr<SearchEnvironment> CreateEnvironment(
141+
UciResponder* responder, const OptionsDict* options) const override {
142+
return std::make_unique<InstamoveEnvironment<ValueHeadSearch>>(responder,
143+
options);
144+
}
145+
};
146+
71147
REGISTER_SEARCH(PolicyHeadFactory);
148+
REGISTER_SEARCH(ValueHeadFactory);
72149

73150
} // namespace instamove
74151
} // namespace lczero

0 commit comments

Comments
 (0)