Skip to content

Commit 97b09e0

Browse files
authored
Add files via upload
New version: enhanced certainty propagation, policy compression, tree balancing, easy early visits and more
1 parent 0be46f2 commit 97b09e0

File tree

5 files changed

+266
-74
lines changed

5 files changed

+266
-74
lines changed

src/main.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
int main(int argc, const char** argv) {
2626
std::cerr << " _" << std::endl;
2727
std::cerr << "| _ | |" << std::endl;
28-
std::cerr << "|_ |_ |_| built " << __DATE__ << std::endl;
28+
std::cerr << "|_ |_ |_| MCTS EXPERIMENTAL " << __DATE__ << std::endl;
2929
using namespace lczero;
3030
CommandLine::Init(argc, argv);
3131
CommandLine::RegisterMode("uci", "(default) Act as UCI engine");

src/mcts/node.cc

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,11 @@ void Node::ResetStats() {
206206
n_ = 0;
207207
v_ = 0.0;
208208
q_ = 0.0;
209-
w_ = 0.0;
209+
w_ = 0.0; // no longer needed
210210
p_ = 0.0;
211+
m_ = 0.0;
212+
b_ = 0.0;
213+
avg_child_branches_ = 0.0;
211214
max_depth_ = 0;
212215
full_depth_ = 0;
213216
is_terminal_ = false;
@@ -219,7 +222,7 @@ std::string Node::DebugString() const {
219222
oss << "Move: " << move_.as_string() << " Term:" << is_terminal_
220223
<< " This:" << this << " Parent:" << parent_ << " child:" << child_
221224
<< " sibling:" << sibling_ << " P:" << p_ << " Q:" << q_ << " W:" << w_
222-
<< " N:" << n_ << " N_:" << n_in_flight_;
225+
<< " M:" << m_<< " N:" << n_ << " N_:" << n_in_flight_;
223226
return oss.str();
224227
}
225228

@@ -243,6 +246,7 @@ void Node::MakeCertain(float q) {
243246
is_certain_ = true;
244247
//n_ = UINT32_MAX;
245248
}
249+
// mainly used for root
246250
void Node::UnCertain() {
247251
is_certain_ = false;
248252
}
@@ -255,21 +259,28 @@ bool Node::TryStartScoreUpdate() {
255259
void Node::CancelScoreUpdate() { --n_in_flight_; }
256260

257261
void Node::FinalizeScoreUpdate(float v, float kBackpropagate, int kAutoextend) {
262+
float q_new = 0;
258263
if (!is_certain_)
259264
{
260-
// gamma = 1.00 corresponds to base leela q update
261-
// gamma = 0.75 is my first guess and yields +25 Elo
262-
// lots of room for tuning, optimum probably 0.8-09
263-
q_ += (v - q_) / (std::powf(n_, kBackpropagate) + 1);
264-
// normal leela update in q + (v-q) format
265-
// q_ += (v - q_) / ((float)n_ + 1.0f);
266-
265+
// Update Q:
266+
// Alternatives:
267+
// q_ += (v - q_) / (std::powf(n_, kBackpropagate) + 1);
268+
// q_ += (v - q_) / ((n_ * kBackpropagate) + 1);
269+
q_new = q_ + (v - q_) / ((float)n_ + 1.0f); //Standard update
270+
271+
// Update M:
272+
// Sum of Squared Difference for Variance Computation
273+
// this is a numerically stable online method
274+
m_ += (v - q_)*(v - q_new);
275+
276+
q_ = q_new;
267277
// all nodes have v_ set to NN compute or to certain value
268278
// except nodes with only one legal move (auto-extend=1)
269279
// these must inherit from child
270280
if (HasOnlyOneChild() && (kAutoextend == 1)) {
271281
q_ = -(child_->q_);
272282
v_ = -(child_->v_);
283+
m_ = child_->m_;
273284
}
274285

275286
// Increment N.

src/mcts/node.h

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class Node {
5555
Node* GetParent() const { return parent_; }
5656

5757
// Gets first child
58-
5958
Node* GetFirstChild() const { return child_; }
6059

6160
// Returns whether a node has children.
@@ -80,26 +79,40 @@ class Node {
8079
// Returns Q if number of visits is more than 0,
8180
float GetQ(float default_q) const { return n_ ? q_ : default_q; }
8281
// Returns U / (Puct * N[parent])
83-
float GetU() const { return p_/ (1 + n_ + n_in_flight_); }
82+
float GetU() const { return p_ / (1 + n_ + n_in_flight_); }
8483
// Returns value of Value Head returned from the neural net.
8584
float GetV() const { return v_; }
85+
// Returns the avg. of children branches (not expanded grandchildren)
86+
float GetCB() const { return avg_child_branches_; }
8687
// Returns value of Move probability returned from the neural net.
8788
// (but can be changed by adding Dirichlet noise).
8889
float GetP() const { return p_; }
90+
// Returns branches of this node (number of children)
91+
float GetB() const {return b_;}
92+
// Returns sample variance of q (M / (N - 1)), or default_m if N <= 2.
93+
float GetSigma2(float default_m) const { return n_>2 ? m_/(n_-1):default_m; }
8994
// Returns whether the node is known to be draw/lose/win.
9095
bool IsTerminal() const { return is_terminal_; }
9196
// Returns whether the node is known to have a certain score
9297
bool IsCertain() const { return is_certain_; }
93-
94-
95-
uint16_t GetFullDepth() const { return full_depth_; }
96-
uint16_t GetMaxDepth() const { return max_depth_; }
98+
uint16_t GetFullDepth() const { return full_depth_; }
99+
uint16_t GetMaxDepth() const { return max_depth_; }
97100
// makes node uncertain again (used to make root uncertain when search is initialized)
98101
void UnCertain();
102+
// Sets the node's average of all its children's branches
103+
void SetCB(float val) { avg_child_branches_ = val; }
99104
// Sets node own value (from neural net or win/draw/lose adjudication).
100105
void SetV(float val) { v_ = val; }
101106
// Sets move probability.
102107
void SetP(float val) { p_ = val; }
108+
// Sets Q
109+
void SetQ(float val) { q_ = val; }
110+
// Sets branches (number of children)
111+
void SetB(float val) { b_ = val; }
112+
// Sets n_ for terminal nodes that are
113+
// found when creating children in
114+
// expand node
115+
void SetN1() { n_ = 1; }
103116
// Makes the node terminal and sets its score.
104117
void MakeTerminal(GameResult result);
105118
// Makes the node certain and sets its score
@@ -117,9 +130,11 @@ class Node {
117130
// Updates:
118131
// * N (+=1)
119132
// * N-in-flight (-=1)
120-
// * W (+= v)
121-
// * Q (=w/n)
122-
// Backpropagete and Autoextend modes are currently passed as parameters
133+
// * W (+= v) obsolete
134+
// * Q (+= (v - q) / (n_ + 1))
135+
// * M Sum of Squares of Differences
136+
// kBackpropagate (not used currently) and Autoextend modes are
137+
// currently passed as parameters
123138
// will either be removed if changes become permanent, or replaced
124139
// by a weight parameter.
125140
void FinalizeScoreUpdate(float v, float kBackpropagate, int kAutoextend);
@@ -171,6 +186,11 @@ class Node {
171186
// Probability that this move will be made. From policy head of the neural
172187
// network.
173188
float p_;
189+
// Sum of Squares of Differences from current mean
190+
float m_;
191+
// branch data for tree shaping
192+
float b_;
193+
float avg_child_branches_;
174194
// How many completed visits this node had.
175195
uint32_t n_;
176196
// (aka virtual loss). How many threads currently process this node (started

0 commit comments

Comments
 (0)