Skip to content

Commit 8e51faf

Browse files
committed
Fix transposition table, plus some refactoring to search extensions
1 parent 41bbae5 commit 8e51faf

File tree

7 files changed

+72
-68
lines changed

7 files changed

+72
-68
lines changed

src/algorithms/draw_checker.rs

+8-7
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@ use std::collections::HashMap;
22

33
use chess::Board;
44

5-
pub fn uncount_board(board_played_times_prediction: &mut HashMap<Board, u32>, new_board: &Board) {
5+
pub fn uncount_board(board_played_times_prediction: &mut HashMap<u64, u32>, new_board: &Board) {
6+
let hash = new_board.get_hash();
67
board_played_times_prediction.insert(
78
// TODO Hash it to avoid copying, we need a good hash function for Board
8-
*new_board,
9-
*board_played_times_prediction.get(new_board).unwrap_or(&0) - 1,
9+
hash,
10+
*board_played_times_prediction.get(&hash).unwrap_or(&0) - 1,
1011
);
1112
}
1213

13-
pub fn count_board(board_played_times_prediction: &mut HashMap<Board, u32>, new_board: &Board) {
14+
pub fn count_board(board_played_times_prediction: &mut HashMap<u64, u32>, new_board: &Board) {
15+
let hash = new_board.get_hash();
1416
board_played_times_prediction.insert(
15-
// TODO Hash it to avoid copying, we need a good hash function for Board
16-
*new_board,
17-
*board_played_times_prediction.get(new_board).unwrap_or(&0) + 1,
17+
hash,
18+
*board_played_times_prediction.get(&hash).unwrap_or(&0) + 1,
1819
);
1920
}

src/algorithms/the_algorithm.rs

+44-38
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,16 @@ use tokio::time::{Duration, Instant};
66
use crate::algorithms::{draw_checker, eval};
77
use crate::common::constants::{modules::*, naive_psqt_tables::*, tapered_pesto_psqt_tables::*};
88
use crate::common::utils::{self, module_enabled, piece_value, Stats};
9+
use crate::modules::{alpha_beta, analyze};
910
use crate::modules::search_extensions;
1011
use crate::modules::skip_bad_moves;
1112
use crate::modules::transposition_table::{self, TranspositionEntry};
12-
use crate::modules::{alpha_beta, analyze};
1313

1414
use super::utils::Evaluation;
1515

16-
1716
#[derive(Clone, Debug)]
1817
pub(crate) struct Algorithm {
1918
pub(crate) modules: u32,
20-
transposition_table: HashMap<Board, TranspositionEntry>,
2119
pub(crate) time_per_move: Duration,
2220
/// Number of times that a given board has been played
2321
pub(crate) board_played_times: HashMap<Board, u32>,
@@ -48,7 +46,6 @@ impl Algorithm {
4846
pub(crate) fn new(modules: u32, time_per_move: Duration) -> Self {
4947
Self {
5048
modules,
51-
transposition_table: HashMap::with_capacity(45),
5249
time_per_move,
5350
board_played_times: HashMap::new(),
5451
pawn_hash: HashMap::new(),
@@ -72,9 +69,10 @@ impl Algorithm {
7269
deadline: Option<Instant>,
7370
stats: &mut Stats,
7471
num_extensions: u32,
75-
board_played_times_prediction: &mut HashMap<Board, u32>,
72+
board_played_times_prediction: &mut HashMap<u64, u32>,
7673
mut mg_incremental_psqt_eval: f32,
7774
mut eg_incremental_psqt_eval: f32,
75+
transposition_table: &mut HashMap<u64, TranspositionEntry>,
7876
) -> NodeData {
7977
if depth == 0 {
8078
stats.leaves_visited += 1;
@@ -90,15 +88,15 @@ impl Algorithm {
9088
None,
9189
Some(mg_incremental_psqt_eval + eg_incremental_psqt_eval),
9290
);
93-
if module_enabled(self.modules, TRANSPOSITION_TABLE) {
94-
transposition_table::insert_in_transposition_table(
95-
&mut self.transposition_table,
96-
board,
97-
depth,
98-
stats,
99-
evaluation,
100-
);
101-
}
91+
// if module_enabled(self.modules, TRANSPOSITION_TABLE) {
92+
// transposition_table::insert_in_transposition_table(
93+
// transposition_table,
94+
// board,
95+
// depth,
96+
// stats,
97+
// evaluation,
98+
// );
99+
// }
102100

103101
return NodeData::new(evaluation, None);
104102
}
@@ -114,12 +112,16 @@ impl Algorithm {
114112
return NodeData::new(best_evaluation, None);
115113
}
116114

117-
let transposition_table = if module_enabled(self.modules, TRANSPOSITION_TABLE) {
118-
Some(&self.transposition_table)
119-
} else {
120-
None
121-
};
122-
let mut boards = Self::create_board_list(board, stats, legal_moves, transposition_table);
115+
let mut boards = Self::create_board_list(
116+
board,
117+
stats,
118+
legal_moves,
119+
if module_enabled(self.modules, TRANSPOSITION_TABLE) {
120+
Some(transposition_table)
121+
} else {
122+
None
123+
},
124+
);
123125

124126
// Sort by eval
125127
Self::sort_by_eval(maximise, &mut boards);
@@ -146,16 +148,14 @@ impl Algorithm {
146148
return NodeData::new(best_evaluation, None);
147149
}
148150

149-
let search_extensions = module_enabled(self.modules, SEARCH_EXTENSIONS);
150-
let extend_by = search_extensions::calculate(
151-
num_extensions,
152-
num_legal_moves,
153-
new_board,
154-
search_extensions,
155-
);
151+
let extend_by = if module_enabled(self.modules, SEARCH_EXTENSIONS) {
152+
search_extensions::calculate(num_extensions, num_legal_moves, new_board)
153+
} else {
154+
0
155+
};
156156

157-
let evaluation = if let Some(transposition_entry) = transposition_entry {
158-
transposition_entry.evaluation
157+
let evaluation = if transposition_entry.is_some_and(|entry| entry.depth >= depth) {
158+
transposition_entry.unwrap().evaluation
159159
} else {
160160
draw_checker::count_board(board_played_times_prediction, &new_board);
161161
let evaluation = self.node_eval_recursive(
@@ -170,6 +170,7 @@ impl Algorithm {
170170
board_played_times_prediction,
171171
mg_incremental_psqt_eval,
172172
eg_incremental_psqt_eval,
173+
transposition_table,
173174
);
174175
draw_checker::uncount_board(board_played_times_prediction, &new_board);
175176
debug_data = evaluation.debug_data;
@@ -259,9 +260,9 @@ impl Algorithm {
259260
Some(mg_incremental_psqt_eval + eg_incremental_psqt_eval);
260261
}
261262

262-
if module_enabled(self.modules, TRANSPOSITION_TABLE) && depth >= 3 {
263+
if module_enabled(self.modules, TRANSPOSITION_TABLE) {
263264
transposition_table::insert_in_transposition_table(
264-
&mut self.transposition_table,
265+
transposition_table,
265266
board,
266267
depth,
267268
stats,
@@ -287,7 +288,7 @@ impl Algorithm {
287288
board: &Board,
288289
stats: &mut Stats,
289290
legal_moves: MoveGen,
290-
transposition_table: Option<&HashMap<Board, TranspositionEntry>>,
291+
transposition_table: Option<&HashMap<u64, TranspositionEntry>>,
291292
) -> Vec<(ChessMove, Board, Option<TranspositionEntry>)> {
292293
legal_moves
293294
.map(|chess_move| {
@@ -331,6 +332,7 @@ impl Algorithm {
331332
board: &Board,
332333
depth: u32,
333334
deadline: Option<Instant>,
335+
transposition_table: &mut HashMap<u64, TranspositionEntry>,
334336
) -> (Option<Action>, Vec<String>, Stats) {
335337
let mut stats = Stats::default();
336338
let out = self.node_eval_recursive(
@@ -345,6 +347,7 @@ impl Algorithm {
345347
&mut HashMap::new(),
346348
0.,
347349
0.,
350+
transposition_table,
348351
);
349352
let analyzer_data = out.debug_data.unwrap_or_default();
350353
(out.evaluation.next_action, analyzer_data, stats)
@@ -360,13 +363,16 @@ impl Algorithm {
360363
*self.board_played_times.get(board).unwrap_or(&0) + 1,
361364
);
362365

366+
let mut transposition_table = HashMap::new();
363367
// Guarantee that at least the first layer gets done.
364368
const START_DEPTH: u32 = 1;
365-
let mut deepest_complete_output = self.next_action(board, START_DEPTH, None);
369+
let mut deepest_complete_output =
370+
self.next_action(board, START_DEPTH, None, &mut transposition_table);
366371
let mut deepest_complete_depth = START_DEPTH;
367372

368373
for depth in (deepest_complete_depth + 1)..=10 {
369-
let latest_output = self.next_action(board, depth, Some(deadline));
374+
let latest_output =
375+
self.next_action(board, depth, Some(deadline), &mut transposition_table);
370376
if utils::passed_deadline(deadline) {
371377
// The cancelled layer is the one with this data
372378
deepest_complete_output.2.progress_on_next_layer =
@@ -378,7 +384,6 @@ impl Algorithm {
378384
}
379385
}
380386
deepest_complete_output.2.depth = deepest_complete_depth;
381-
deepest_complete_output.2.tt_size = self.transposition_table.len() as u32;
382387

383388
let mut action = match deepest_complete_output.0 {
384389
Some(action) => action,
@@ -410,7 +415,7 @@ impl Algorithm {
410415
pub(crate) fn eval(
411416
&mut self,
412417
board: &Board,
413-
board_played_times_prediction: &HashMap<Board, u32>,
418+
board_played_times_prediction: &HashMap<u64, u32>,
414419
mg_incremental_psqt_eval: f32,
415420
eg_incremental_psqt_eval: f32,
416421
) -> f32 {
@@ -426,7 +431,9 @@ impl Algorithm {
426431
};
427432
}
428433
let board_played_times = *self.board_played_times.get(board).unwrap_or(&0)
429-
+ *board_played_times_prediction.get(board).unwrap_or(&0);
434+
+ *board_played_times_prediction
435+
.get(&board.get_hash())
436+
.unwrap_or(&0);
430437
if board_played_times >= 2 {
431438
// This is third time this is played. Draw by three-fold repetition
432439
return 0.;
@@ -652,7 +659,6 @@ impl Algorithm {
652659
}
653660

654661
pub(crate) fn reset(&mut self) {
655-
self.transposition_table = HashMap::new();
656662
self.board_played_times = HashMap::new();
657663
self.pawn_hash = HashMap::new();
658664
self.naive_psqt_pawn_hash = HashMap::new();

src/common/utils.rs

-4
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ pub(crate) struct Stats {
8989
pub(crate) transposition_table_entries: u32,
9090
pub(crate) transposition_table_accesses: u32,
9191
pub(crate) time_for_transposition_access: Duration,
92-
pub(crate) tt_size: u32,
9392
}
9493

9594
impl AddAssign for Stats {
@@ -105,7 +104,6 @@ impl AddAssign for Stats {
105104
self.transposition_table_entries += rhs.transposition_table_entries;
106105
self.transposition_table_accesses += rhs.transposition_table_accesses;
107106
self.time_for_transposition_access += rhs.time_for_transposition_access;
108-
self.tt_size += rhs.tt_size;
109107
}
110108
}
111109

@@ -125,7 +123,6 @@ impl Div<u32> for Stats {
125123
transposition_table_entries: self.transposition_table_entries as f32 / rhs as f32,
126124
transposition_table_accesses: self.transposition_table_accesses as f32 / rhs as f32,
127125
time_for_transposition_access: self.time_for_transposition_access / rhs,
128-
tt_size: self.tt_size as f32 / rhs as f32,
129126
}
130127
}
131128
}
@@ -144,7 +141,6 @@ pub(crate) struct StatsAverage {
144141
pub(crate) transposition_table_entries: f32,
145142
pub(crate) transposition_table_accesses: f32,
146143
pub(crate) time_for_transposition_access: Duration,
147-
pub(crate) tt_size: f32,
148144
}
149145

150146
pub(crate) fn passed_deadline(deadline: Instant) -> bool {

src/main.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ mod pitter;
2525
//If we should print the moves played and results of each game.
2626
pub(crate) const PRINT_GAME: bool = true;
2727
//If we should test all possible pairs of combinations.
28-
const TEST_ALL_PAIRS: bool = true;
28+
const TEST_ALL_PAIRS: bool = false;
2929

3030
#[tokio::main]
3131
async fn main() {
@@ -38,7 +38,7 @@ async fn main() {
3838
let modules2 = 0;
3939
let time_per_move1 = Duration::from_micros(2000);
4040
let time_per_move2 = Duration::from_micros(2000);
41-
let game_pairs = 50;
41+
let game_pairs = 150;
4242

4343
//Run competition
4444
let result = do_competition(
@@ -57,7 +57,7 @@ async fn main() {
5757
);
5858
let time_per_move1 = Duration::from_micros(2000);
5959
let time_per_move2 = Duration::from_micros(2000);
60-
let game_pairs = 200;
60+
let game_pairs = 400;
6161

6262
let mut competitions_run: u32 = 0;
6363
let mut dp: Vec<Vec<Option<CompetitionResults>>> =

src/modules/search_extensions.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ pub fn calculate(
44
num_extensions: u32,
55
num_legal_moves: usize,
66
new_board: Board,
7-
search_extensions: bool,
87
) -> u32 {
9-
if !search_extensions || num_extensions > 3 {
8+
if num_extensions > 3 {
109
0
11-
} else if num_legal_moves == 1 || new_board.checkers().popcnt() >= 2 {
10+
} else if num_legal_moves <= 3 || new_board.checkers().popcnt() >= 2 {
1211
1
1312
} else {
1413
0

src/modules/transposition_table.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,29 @@ impl TranspositionEntry {
1919
}
2020

2121
pub(crate) fn insert_in_transposition_table(
22-
transposition_table: &mut HashMap<Board, TranspositionEntry>,
22+
transposition_table: &mut HashMap<u64, TranspositionEntry>,
2323
board: &Board,
2424
depth: u32,
2525
stats: &mut Stats,
2626
evaluation: Evaluation,
2727
) {
2828
let start = Instant::now();
29-
transposition_table.insert(*board, TranspositionEntry::new(depth, evaluation));
29+
transposition_table.insert(board.get_hash(), TranspositionEntry::new(depth, evaluation));
3030
stats.time_for_transposition_access += Instant::now() - start;
3131
stats.transposition_table_entries += 1
3232
}
3333

3434
pub(crate) fn get_transposition_entry(
35-
transposition_table: &HashMap<Board, TranspositionEntry>,
35+
transposition_table: &HashMap<u64, TranspositionEntry>,
3636
stats: &mut Stats,
3737
board: &Board,
3838
) -> Option<TranspositionEntry> {
3939
let start = Instant::now();
4040

41-
let transposition_entry = transposition_table.get(board).copied();
42-
41+
let transposition_entry = transposition_table.get(&board.get_hash()).copied();
42+
if transposition_entry.is_some() {
43+
stats.transposition_table_accesses += 1;
44+
}
4345
let time_for_transposition_access = Instant::now() - start;
4446
stats.time_for_transposition_access += time_for_transposition_access;
4547

src/pitter/logic.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,14 @@ impl Competition {
257257
for task in tasks {
258258
let _ = task.await;
259259
}
260-
// let sum_stats = sum_stats.lock().await;
261-
// let avg_stats = (
262-
// sum_stats.0 / sum_stats.0.num_plies,
263-
// sum_stats.1 / sum_stats.1.num_plies,
264-
// );
265-
266-
//println!("Stats for algo1: {:#?}", avg_stats.0);
267-
//println!("Stats for algo2: {:#?}", avg_stats.1);
260+
let sum_stats = sum_stats.lock().await;
261+
let avg_stats = (
262+
sum_stats.0 / sum_stats.0.num_plies,
263+
sum_stats.1 / sum_stats.1.num_plies,
264+
);
265+
266+
println!("Stats for algo1: {:#?}", avg_stats.0);
267+
println!("Stats for algo2: {:#?}", avg_stats.1);
268268

269269
// Gives E0597 otherwise
270270
#[allow(clippy::let_and_return)]

0 commit comments

Comments
 (0)