diff --git a/CMakeLists.txt b/CMakeLists.txt index 66d66dea..d6158b06 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,7 +146,7 @@ set(srcs src/util/bit.hpp src/util/parse.hpp src/util/pretty.hpp - src/util/large_pages.hpp + src/util/mem.hpp src/util/static_vector.hpp src/util/types.hpp src/util/vec/sse2.hpp diff --git a/src/position.cpp b/src/position.cpp index 738cc2ae..0d3fda66 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -4,6 +4,7 @@ #include "common.hpp" #include "geometry.hpp" #include "psqt_state.hpp" +#include "util/mem.hpp" #include "util/parse.hpp" #include "util/types.hpp" #include "zobrist.hpp" @@ -366,7 +367,7 @@ void Position::add_attacks(bool color, PieceId id, Square sq, PieceType ptype, m } template -Position Position::move(Move m, PsqtState* psqtState) const { +Position Position::move(Move m, PsqtState* psqtState, const TT* tt) const { Position new_pos = *this; PsqtUpdates updates{}; @@ -514,6 +515,11 @@ Position Position::move(Move m, PsqtState* psqtState) const { new_pos.m_rook_info[0].as_index() | (new_pos.m_rook_info[1].as_index() << 2); new_pos.m_hash_key ^= Zobrist::castling_zobrist[new_castle_index]; + // Prefetch hash key tt entry + if (tt != nullptr) { + prefetch(tt->addr_key(new_pos.m_hash_key)); + } + new_pos.m_active_color = invert(m_active_color); new_pos.m_ply++; @@ -524,8 +530,8 @@ Position Position::move(Move m, PsqtState* psqtState) const { return new_pos; } -template Position Position::move(Move m, PsqtState* psqtState) const; -template Position Position::move(Move m, PsqtState* psqtState) const; +template Position Position::move(Move m, PsqtState* psqtState, const TT* tt = nullptr) const; +template Position Position::move(Move m, PsqtState* psqtState, const TT* tt = nullptr) const; Position Position::null_move() const { Position new_pos = *this; diff --git a/src/position.hpp b/src/position.hpp index 92fa28e7..249014ed 100644 --- a/src/position.hpp +++ b/src/position.hpp @@ -3,6 +3,7 @@ #include "board.hpp" #include "move.hpp" #include "square.hpp" +#include "tt.hpp" #include "util/types.hpp" #include #include @@ -12,6 +13,8 @@ namespace Clockwork { +class TT; + struct PsqtState; struct PsqtUpdates; @@ -259,15 +262,18 @@ struct Position { } template - [[nodiscard]] Position move(Move m, PsqtState* psqt_state) const; - [[nodiscard]] Position null_move() const; - - [[nodiscard]] Position move(Move m) const { - return move(m, nullptr); + [[nodiscard]] Position move(Move m, PsqtState* psqt_state, const TT* tt = nullptr) const; + [[nodiscard]] Position move(Move m, PsqtState& psqt_state, const TT* tt = nullptr) const { + return move(m, &psqt_state, tt); } [[nodiscard]] Position move(Move m, PsqtState& psqt_state) const { return move(m, &psqt_state); } + [[nodiscard]] Position move(Move m) const { + return move(m, nullptr); + } + + [[nodiscard]] Position null_move() const; [[nodiscard]] std::tuple calc_pin_mask() const; diff --git a/src/search.cpp b/src/search.cpp index 6269d25c..0f2ab238 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -10,8 +10,8 @@ #include "tm.hpp" #include "tuned.hpp" #include "uci.hpp" -#include "util/large_pages.hpp" #include "util/log2.hpp" +#include "util/mem.hpp" #include "util/types.hpp" #include #include @@ -555,7 +555,7 @@ Value Worker::search( for (Move m = moves.next(); m != Move::none(); m = moves.next()) { ss->cont_hist_entry = &m_td.history.get_cont_hist_entry(pos, m); - Position pos_after = pos.move(m, m_td.push_psqt_state()); + Position pos_after = pos.move(m, m_td.push_psqt_state(), &m_searcher.tt); repetition_info.push(pos_after.get_hash_key(), pos_after.is_reversible(m)); Value probcut_value = @@ -708,7 +708,7 @@ Value Worker::search( // Do move ss->cont_hist_entry = &m_td.history.get_cont_hist_entry(pos, m); - Position pos_after = pos.move(m, m_td.push_psqt_state()); + Position pos_after = pos.move(m, m_td.push_psqt_state(), &m_searcher.tt); moves_played++; // Put hash into repetition table. TODO: encapsulate this and any other future adjustment to do "on move" into a proper function @@ -988,7 +988,7 @@ Value Worker::quiesce(const Position& pos, Stack* ss, Value alpha, Value beta, i // Do move ss->cont_hist_entry = &m_td.history.get_cont_hist_entry(pos, m); - Position pos_after = pos.move(m, m_td.push_psqt_state()); + Position pos_after = pos.move(m, m_td.push_psqt_state(), &m_searcher.tt); moves_searched++; // If we've found a legal move, then we can begin skipping quiet moves. diff --git a/src/tt.cpp b/src/tt.cpp index 52368c91..4fee1c47 100644 --- a/src/tt.cpp +++ b/src/tt.cpp @@ -85,6 +85,11 @@ std::optional TT::probe(const Position& pos, i32 ply) const { return {}; } +TTClusterMemory* TT::addr_key(const u64 key) const { + size_t idx = mulhi64(key, m_size); + return &this->m_clusters[idx]; +} + void TT::store(const Position& pos, i32 ply, Value eval, diff --git a/src/tt.hpp b/src/tt.hpp index 5984c59b..8488ebd9 100644 --- a/src/tt.hpp +++ b/src/tt.hpp @@ -1,7 +1,7 @@ #pragma once #include "position.hpp" -#include "util/large_pages.hpp" +#include "util/mem.hpp" #include #include #include @@ -105,6 +105,8 @@ class TT { void clear(); void increment_age(); i32 hashfull() const; + TTClusterMemory* addr_key(const u64 key) const; + private: unique_ptr_huge_page m_clusters; diff --git a/src/util/large_pages.hpp b/src/util/mem.hpp similarity index 97% rename from src/util/large_pages.hpp rename to src/util/mem.hpp index 75d426e2..13508c2a 100644 --- a/src/util/large_pages.hpp +++ b/src/util/mem.hpp @@ -13,6 +13,8 @@ #include #endif +// Large page allocation utilities + template using unique_ptr_huge_page = std::conditional_t, @@ -154,3 +156,9 @@ unique_ptr_huge_page make_unique_for_overwrite_huge_page(std::size_t n) { template requires std::is_bounded_array_v void make_unique_for_overwrite_huge_page(Args&&...) = delete; + + +// Prefetching utilities +inline void prefetch(const void* ptr) { + __builtin_prefetch(ptr); +}