diff --git a/CMakeLists.txt b/CMakeLists.txt index f771042..7f57b0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,7 @@ option(BUILD_TESTS "Build tests" OFF) set(CMAKE_CXX_STANDARD 26) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_EXTENSIONS ON) add_compile_options(-Wall -Wextra -Wpedantic) if(CMAKE_BUILD_TYPE STREQUAL "Release") diff --git a/Chess/Attacks.cpp b/Chess/Attacks.cpp index b57148c..dfdbf99 100644 --- a/Chess/Attacks.cpp +++ b/Chess/Attacks.cpp @@ -1,4 +1,5 @@ #include "Attacks.h" + #include #include "BitBoard.h" @@ -29,8 +30,7 @@ Bitboard SimpleChessEngine::GenerateAttackMask(const BitIndex square, for (auto direction : GetStepDelta()) { Bitboard step; - for (BitIndex temp = square; - (occupancy & SingleSquare(temp)).None();) { + for (BitIndex temp = square; (occupancy & SingleSquare(temp)).None();) { step = DoShiftIfValid(temp, direction).value_or(kEmptyBoard); result |= step; if (step.None()) break; diff --git a/Chess/DebugInfo.h b/Chess/DebugInfo.h index 0a9758a..93a4482 100644 --- a/Chess/DebugInfo.h +++ b/Chess/DebugInfo.h @@ -13,6 +13,9 @@ struct DebugInfo { std::size_t nmp_cuts{}; std::size_t tt_hits{}; + std::size_t tt_pv_misses{}; + std::size_t tt_other_misses{}; + std::size_t tt_wrong_moves{}; std::size_t tt_cuts{}; DebugInfo &operator+=(const DebugInfo &other) { @@ -23,6 +26,9 @@ struct DebugInfo { nmp_tries += other.nmp_tries; nmp_cuts += other.nmp_cuts; tt_hits += other.tt_hits; + tt_pv_misses += other.tt_pv_misses; + tt_other_misses += other.tt_other_misses; + tt_wrong_moves += other.tt_wrong_moves; tt_cuts += other.tt_cuts; return *this; } diff --git a/Chess/Eval.h b/Chess/Eval.h new file mode 100644 index 0000000..8608045 --- /dev/null +++ b/Chess/Eval.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace SimpleChessEngine { + +struct Eval { + std::int16_t value{0}; + + constexpr Eval() = default; + constexpr Eval(std::int16_t v) : value(v) {} + constexpr Eval(int v) : value(static_cast(v)) {} + constexpr Eval(std::size_t v) : value(static_cast(v)) {} + + constexpr auto operator<=>(const Eval&) const = default; + constexpr bool operator==(const Eval&) const = default; + + constexpr Eval operator-() const { return Eval(-value); } + constexpr Eval operator+() const { return *this; } + + constexpr Eval operator+(const Eval& other) const { + return Eval(value + other.value); + } + constexpr Eval operator-(const Eval& other) const { + return Eval(value - other.value); + } + constexpr Eval operator*(const std::integral auto scale) const { + return Eval(value * scale); + } + constexpr Eval operator/(const std::integral auto scale) const { + return Eval(value / scale); + } + + constexpr Eval& operator+=(const Eval& other) { + value += other.value; + return *this; + } + constexpr Eval& operator-=(const Eval& other) { + value -= other.value; + return *this; + } + constexpr Eval& operator*=(const std::integral auto scale) { + value *= scale; + return *this; + } + constexpr Eval& operator/=(const std::integral auto scale) { + value /= scale; + return *this; + } +}; + +inline Eval operator*(const std::integral auto scale, const Eval& eval) { + return eval * scale; +} + +inline std::ostream& operator<<(std::ostream& os, const Eval& eval) { + return os << eval.value; +} + +inline constexpr Eval operator""_ev(unsigned long long value) { + return Eval(static_cast(value)); +} +} // namespace SimpleChessEngine + +namespace std { + template<> + constexpr SimpleChessEngine::Eval numeric_limits::min() noexcept { + return SimpleChessEngine::Eval(numeric_limits::min()); + } + template<> + constexpr SimpleChessEngine::Eval numeric_limits::max() noexcept{ + return SimpleChessEngine::Eval(numeric_limits::max()); + } + constexpr SimpleChessEngine::Eval abs(const SimpleChessEngine::Eval eval) noexcept { + return SimpleChessEngine::Eval(std::abs(eval.value)); + } +} \ No newline at end of file diff --git a/Chess/Evaluation.cpp b/Chess/Evaluation.cpp index a4c638d..076340d 100644 --- a/Chess/Evaluation.cpp +++ b/Chess/Evaluation.cpp @@ -6,10 +6,10 @@ namespace SimpleChessEngine { kPhaseValueLimits[static_cast(GamePhase::kMiddleGame)]; const auto eg_limit = kPhaseValueLimits[static_cast(GamePhase::kEndGame)]; - pv = std::clamp(pv, eg_limit, mg_limit); - return (eval[static_cast(GamePhase::kMiddleGame)] * (pv - eg_limit) + - eval[static_cast(GamePhase::kEndGame)] * (mg_limit - pv)) / - kLimitsDifference; + pv = std::clamp(pv, eg_limit, mg_limit); + return (eval[static_cast(GamePhase::kMiddleGame)].value * (pv - eg_limit).value + + eval[static_cast(GamePhase::kEndGame)].value * (mg_limit - pv).value) / + kLimitsDifference.value; } [[nodiscard]] Eval Position::Evaluate() const { diff --git a/Chess/Evaluation.h b/Chess/Evaluation.h index 994a455..6a5198e 100644 --- a/Chess/Evaluation.h +++ b/Chess/Evaluation.h @@ -3,17 +3,17 @@ #include #include +#include "Eval.h" #include "Utility.h" namespace SimpleChessEngine { -using Eval = int; using SearchResult = std::optional; enum class GamePhase : std::uint8_t { kMiddleGame, kEndGame }; constexpr size_t kGamePhases = 2; -using PhaseValue = int; +using PhaseValue = Eval; struct TaperedEval { std::array eval{}; @@ -68,19 +68,19 @@ constexpr Eval kFullNonPawnMaterial = .eval[static_cast(GamePhase::kMiddleGame)] * 2; -constexpr std::array kPhaseValueLimits = {kFullNonPawnMaterial, 0}; +constexpr std::array kPhaseValueLimits = {kFullNonPawnMaterial, Eval{0}}; constexpr PhaseValue kLimitsDifference = kPhaseValueLimits[0] - kPhaseValueLimits[1]; constexpr Eval kTempoBonus = 20; -constexpr Eval kMateValue = -100'000; +constexpr Eval kMateValue = -10'000; constexpr Eval kDrawValue = 0; // returns zero if the score is not mate value // otherwise returns 1 if it is winning (positive), -1 if losing (negative) -inline int IsMateScore(const int score) { +inline int IsMateScore(const Eval score) { if (-std::abs(score) > kMateValue + static_cast(kMaxSearchPly) + 1) { return 0; } diff --git a/Chess/KillerTable.h b/Chess/KillerTable.h index 7dff7e2..81fe50e 100644 --- a/Chess/KillerTable.h +++ b/Chess/KillerTable.h @@ -8,7 +8,7 @@ template class KillerTable { public: void Clear() { killer_total_count_.fill(0); } - void TryAdd(const Depth ply, const Move& move) { + void TryAdd(const Depth ply, Move move) { if (!Contains(ply, move)) { table_[ply][killer_total_count_[ply]++ % MaxKillerCount] = move; } @@ -16,13 +16,13 @@ class KillerTable { size_t AvailableKillerCount(const Depth ply) const { return std::min(killer_total_count_[ply], MaxKillerCount); } - const Move& Get(const Depth ply, const size_t index) const { + Move Get(const Depth ply, const size_t index) const { assert(index < killer_total_count_[ply] && index < MaxKillerCount); return table_[ply][index]; } private: - bool Contains(const Depth ply, const Move& move) { + bool Contains(const Depth ply, Move move) { for (size_t i = 0; i < std::min(killer_total_count_[ply], MaxKillerCount); ++i) { if (move == table_[ply][i]) return true; diff --git a/Chess/Move.h b/Chess/Move.h index 61a1990..03da381 100644 --- a/Chess/Move.h +++ b/Chess/Move.h @@ -1,7 +1,12 @@ #pragma once -#include #include -#include +#include +#include +#include +#include +#include +#include + #include "BitBoard.h" #include "Piece.h" @@ -10,92 +15,568 @@ namespace SimpleChessEngine { struct NullMove {}; enum class MoveType : std::uint16_t { - kNormal = 0, - kCastling = 1 << 14, - kEnPassant = 2 << 14, - kPromotion = 3 << 14, + kNormal = 0, + kCastling = 1 << 14, + kEnPassant = 2 << 14, + kPromotion = 3 << 14, }; class Move { -private: - static constexpr std::uint16_t kSquareMask = 0x3F; - static constexpr std::uint16_t kPromotionMask = 0x3; - static constexpr std::uint16_t kTypeMask = 0x3; - static constexpr std::uint8_t kFromShift = 6; - static constexpr std::uint8_t kPromotionShift = 12; - static constexpr std::uint8_t kTypeShift = 14; - static constexpr std::uint16_t kNullValue = 65; - static constexpr std::uint16_t kNoneValue = 0; + private: + static constexpr std::uint16_t kSquareMask = 0x3F; + static constexpr std::uint16_t kPromotionMask = 0x3; + static constexpr std::uint16_t kTypeMask = 0x3; + static constexpr std::uint8_t kFromShift = 6; + static constexpr std::uint8_t kPromotionShift = 12; + static constexpr std::uint8_t kTypeShift = 14; + static constexpr std::uint16_t kNullValue = 65; + static constexpr std::uint16_t kNoneValue = 0; + + public: + Move() = default; + constexpr explicit Move(std::uint16_t data) : data_(data) {} + constexpr Move(BitIndex from, BitIndex to) + : data_((from << kFromShift) + to) {} + + template + static constexpr Move Make(BitIndex from, BitIndex to, + Piece promotion_piece = Piece::kKnight) { + return Move(static_cast(T) + + ((static_cast(promotion_piece) - + static_cast(Piece::kKnight)) + << kPromotionShift) + + (from << kFromShift) + to); + } + + constexpr BitIndex From() const { + assert(IsValid()); + return static_cast((data_ >> kFromShift) & kSquareMask); + } + + constexpr BitIndex To() const { + assert(IsValid()); + return static_cast(data_ & kSquareMask); + } + + constexpr MoveType Type() const { + return static_cast(data_ & (kTypeMask << kTypeShift)); + } + + constexpr Piece PromotionPiece() const { + return static_cast(((data_ >> kPromotionShift) & kPromotionMask) + + static_cast(Piece::kKnight)); + } + + constexpr bool IsValid() const { + return data_ != kNoneValue && data_ != kNullValue; + } + + constexpr bool IsPromotion() const { return Type() == MoveType::kPromotion; } + constexpr bool IsEnPassant() const { return Type() == MoveType::kEnPassant; } + constexpr bool IsCastling() const { return Type() == MoveType::kCastling; } + constexpr bool IsNormal() const { return Type() == MoveType::kNormal; } + + template + bool IsQuiet(const Position& position) const { + if (IsPromotion() || IsEnPassant()) return false; + if (IsCastling()) return true; + return position.GetPieceAt(To()) == Piece::kNone; + } + + static constexpr Move Null() { return Move(kNullValue); } + static constexpr Move None() { return Move(kNoneValue); } + + constexpr bool operator==(const Move& other) const = default; + constexpr bool operator!=(const Move& other) const = default; + constexpr explicit operator bool() const { return data_ != 0; } + constexpr std::uint16_t Raw() const { return data_; } + + private: + std::uint16_t data_ = {}; +}; + +enum class CastlingSide : std::uint8_t { k00, k000 }; + +struct PseudoLegalTag {}; +struct LegalTag {}; + +template +concept MoveTag = std::same_as || std::same_as; + +template +concept TagConvertible = + (std::same_as) || + (std::same_as && std::same_as); + +template +struct RefTypeTraits; + +template <> +struct RefTypeTraits { + using ValueType = Move; + using ReferenceType = Move&; + using ConstReferenceType = const Move&; + static constexpr bool is_const = false; + static constexpr bool is_reference = true; +}; + +template <> +struct RefTypeTraits { + using ValueType = Move; + using ReferenceType = const Move&; + using ConstReferenceType = const Move&; + static constexpr bool is_const = true; + static constexpr bool is_reference = true; +}; + +template <> +struct RefTypeTraits { + using ValueType = Move; + using ReferenceType = Move&; + using ConstReferenceType = const Move&; + static constexpr bool is_const = false; + static constexpr bool is_reference = false; +}; + +template + requires MoveTag && + (std::same_as || std::same_as || + std::same_as) +class TypedMove { + private: + using Traits = RefTypeTraits; + using StorageType = std::conditional_t< + Traits::is_reference, + std::conditional_t, Move>; + + StorageType storage_; + + template + requires IsRef + explicit TypedMove(typename Traits::ReferenceType move) : storage_(&move) {} + + template + requires(!IsRef) + explicit TypedMove(Move move) : storage_(move) {} + + template + friend inline TypedMove MakeTypedMove( + typename RefTypeTraits::ReferenceType); + + template + friend inline TypedMove MakeTypedMove(Move); + + template + requires MoveTag + friend void swap(TypedMove&, TypedMove&) noexcept; + + public: + template + requires(!IsRef) + TypedMove() : storage_() {} + + TypedMove(const TypedMove&) = default; + TypedMove& operator=(const TypedMove&) = default; + + TypedMove(TypedMove&&) noexcept = default; + TypedMove& operator=(TypedMove&&) noexcept = default; + + // Conversion from TypedMove to TypedMove + template + requires(std::same_as && std::same_as) + TypedMove(TypedMove& other) : storage_(&other.get()) {} + + // Conversion from TypedMove to TypedMove + template + requires(std::same_as && + (std::same_as || + std::same_as)) + TypedMove(const TypedMove& other) + : storage_(&other.get()) {} + + template + requires TagConvertible && + std::same_as + TypedMove(const TypedMove& other) { + if constexpr (Traits::is_reference) { + storage_ = &other.get(); + } else { + storage_ = other.get(); + } + } + + constexpr typename Traits::ConstReferenceType get() const { + if constexpr (Traits::is_reference) { + return *storage_; + } else { + return storage_; + } + } + + constexpr typename Traits::ReferenceType get() + requires(!Traits::is_const) + { + if constexpr (Traits::is_reference) { + return *storage_; + } else { + return storage_; + } + } + + constexpr operator typename Traits::ConstReferenceType() const& { + return get(); + } + + constexpr operator Move() const&& { + if constexpr (Traits::is_reference) { + return *storage_; + } else { + return storage_; + } + } + + constexpr operator typename Traits::ReferenceType() & + requires(!Traits::is_const) + { + return get(); + } + + // Conversion from reference types to value type + template + requires(Traits::is_reference) + constexpr operator TypedMove() const { + return MakeTypedMove(get()); + } + + constexpr BitIndex From() const { return get().From(); } + constexpr BitIndex To() const { return get().To(); } + constexpr MoveType Type() const { return get().Type(); } + constexpr Piece PromotionPiece() const { return get().PromotionPiece(); } + constexpr bool IsValid() const { return get().IsValid(); } + constexpr bool IsPromotion() const { return get().IsPromotion(); } + constexpr bool IsEnPassant() const { return get().IsEnPassant(); } + constexpr bool IsCastling() const { return get().IsCastling(); } + constexpr bool IsNormal() const { return get().IsNormal(); } + constexpr std::uint16_t Raw() const { return get().Raw(); } + constexpr explicit operator bool() const { return get().operator bool(); } -public: - Move() = default; - constexpr explicit Move(std::uint16_t data) : data_(data) {} - constexpr Move(BitIndex from, BitIndex to) : data_((from << kFromShift) + to) {} + template + bool IsQuiet(const Position& position) const { + return get().IsQuiet(position); + } - template - static constexpr Move Make(BitIndex from, BitIndex to, Piece promotion_piece = Piece::kKnight) { - return Move(static_cast(T) + - ((static_cast(promotion_piece) - static_cast(Piece::kKnight)) << kPromotionShift) + - (from << kFromShift) + to); + constexpr bool operator==(const TypedMove& other) const { + return get() == other.get(); + } + constexpr bool operator!=(const TypedMove& other) const { + return get() != other.get(); + } + + constexpr bool operator==(const Move& other) const { return get() == other; } + constexpr bool operator!=(const Move& other) const { return get() != other; } +}; + +template + requires MoveTag +void swap(TypedMove& lhs, TypedMove& rhs) noexcept { + Move temp = lhs.get(); + lhs = TypedMove(rhs.get()); + rhs = TypedMove(temp); +} + +template +inline TypedMove MakeTypedMove( + typename RefTypeTraits::ReferenceType move) { + static_assert(MoveTag, "Tag must be a valid MoveTag"); + return TypedMove(move); +} + +template +inline TypedMove MakeTypedMove(Move move) { + static_assert(MoveTag, "Tag must be a valid MoveTag"); + return TypedMove(move); +} + +using PseudoLegalMoveRef = TypedMove; +using PseudoLegalMoveConstRef = TypedMove; +using LegalMoveRef = TypedMove; +using LegalMoveConstRef = TypedMove; + +using PseudoLegalMove = TypedMove; +using LegalMove = TypedMove; + +static_assert(std::is_constructible_v, + "LegalMoveRef must be constructible from LegalMove&"); +static_assert(std::is_constructible_v, + "LegalMoveConstRef must be constructible from const LegalMove&"); +static_assert(std::is_constructible_v, + "LegalMoveConstRef must be constructible from LegalMove&"); + +static_assert(std::is_constructible_v, + "PseudoLegalMoveRef must be constructible from PseudoLegalMove&"); +static_assert( + std::is_constructible_v, + "PseudoLegalMoveConstRef must be constructible from const " + "PseudoLegalMove&"); +static_assert( + std::is_constructible_v, + "PseudoLegalMoveConstRef must be constructible from PseudoLegalMove&"); + +template + requires MoveTag +class MoveList { + public: + using ValueType = Move; + using Container = std::vector; + using SizeType = typename Container::size_type; + + using Reference = TypedMove; + using ConstReference = TypedMove; + + template + class Iterator { + private: + using BaseIterator = + std::conditional_t; + + BaseIterator iter_; + + public: + struct ArrowProxy { + std::conditional_t move; + auto operator->() { return &move; } + auto operator->() const { return &move; } + }; + + using iterator_category = std::random_access_iterator_tag; + using value_type = Move; + using difference_type = typename BaseIterator::difference_type; + using pointer = ArrowProxy; + using reference = std::conditional_t; + + Iterator() = default; + explicit Iterator(BaseIterator iter) : iter_(iter) {} + + template + requires(IsConst && !WasConst) + Iterator(const Iterator& other) : iter_(other.iter_) {} + + reference operator*() const { + if constexpr (IsConst) { + return MakeTypedMove(*iter_); + } else { + return MakeTypedMove(*iter_); + } } - constexpr BitIndex From() const { - assert(IsValid()); - return static_cast((data_ >> kFromShift) & kSquareMask); + pointer operator->() const { + if constexpr (IsConst) { + return ArrowProxy{MakeTypedMove(*iter_)}; + } else { + return ArrowProxy{MakeTypedMove(*iter_)}; + } } - constexpr BitIndex To() const { - assert(IsValid()); - return static_cast(data_ & kSquareMask); + Iterator& operator++() { + ++iter_; + return *this; + } + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; } + Iterator& operator--() { + --iter_; + return *this; + } + Iterator operator--(int) { + Iterator tmp = *this; + --iter_; + return tmp; + } + + Iterator& operator+=(difference_type n) { + iter_ += n; + return *this; + } + Iterator& operator-=(difference_type n) { + iter_ -= n; + return *this; + } + + Iterator operator+(difference_type n) const { return Iterator(iter_ + n); } + Iterator operator-(difference_type n) const { return Iterator(iter_ - n); } - constexpr MoveType Type() const { - return static_cast(data_ & (kTypeMask << kTypeShift)); + difference_type operator-(const Iterator& other) const { + return iter_ - other.iter_; } - constexpr Piece PromotionPiece() const { - return static_cast(((data_ >> kPromotionShift) & kPromotionMask) + static_cast(Piece::kKnight)); + reference operator[](difference_type n) const { + if constexpr (IsConst) { + return MakeTypedMove(iter_[n]); + } else { + return MakeTypedMove(iter_[n]); + } } - constexpr bool IsValid() const { - return data_ != kNoneValue && data_ != kNullValue; + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } + bool operator!=(const Iterator& other) const { + return iter_ != other.iter_; + } + bool operator<(const Iterator& other) const { return iter_ < other.iter_; } + bool operator<=(const Iterator& other) const { + return iter_ <= other.iter_; + } + bool operator>(const Iterator& other) const { return iter_ > other.iter_; } + bool operator>=(const Iterator& other) const { + return iter_ >= other.iter_; } - constexpr bool IsPromotion() const { return Type() == MoveType::kPromotion; } - constexpr bool IsEnPassant() const { return Type() == MoveType::kEnPassant; } - constexpr bool IsCastling() const { return Type() == MoveType::kCastling; } - constexpr bool IsNormal() const { return Type() == MoveType::kNormal; } - - template - bool IsQuiet(const Position& position) const { - if (IsPromotion() || IsEnPassant()) return false; - if (IsCastling()) return true; - return position.GetPieceAt(To()) == Piece::kNone; + template + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } + template + bool operator!=(const Iterator& other) const { + return iter_ != other.iter_; } - static constexpr Move Null() { return Move(kNullValue); } - static constexpr Move None() { return Move(kNoneValue); } + template + friend class Iterator; + }; - constexpr bool operator==(const Move& other) const = default; - constexpr bool operator!=(const Move& other) const = default; - constexpr explicit operator bool() const { return data_ != 0; } - constexpr std::uint16_t Raw() const { return data_; } + using iterator = Iterator; + using const_iterator = Iterator; - struct Hash { - std::size_t operator()(const Move& move) const { - return move.data_ * 6364136223846793005ULL + 1442695040888963407ULL; - } - }; + MoveList() = default; + explicit MoveList(Container&& moves) : moves_(std::move(moves)) {} + explicit MoveList(const Container& moves) : moves_(moves) {} -private: - std::uint16_t data_; -}; + [[nodiscard]] bool empty() const noexcept { return moves_.empty(); } + [[nodiscard]] SizeType size() const noexcept { return moves_.size(); } + [[nodiscard]] SizeType capacity() const noexcept { return moves_.capacity(); } + void reserve(SizeType n) { moves_.reserve(n); } + void clear() noexcept { moves_.clear(); } -enum class CastlingSide : std::uint8_t { k00, k000 }; + Reference operator[](SizeType pos) { + return MakeTypedMove(moves_[pos]); + } + + ConstReference operator[](SizeType pos) const { + return MakeTypedMove(moves_[pos]); + } + + Reference front() { return MakeTypedMove(moves_.front()); } + ConstReference front() const { + return MakeTypedMove(moves_.front()); + } + + Reference back() { return MakeTypedMove(moves_.back()); } + ConstReference back() const { + return MakeTypedMove(moves_.back()); + } + + iterator begin() noexcept { return iterator(moves_.begin()); } + const_iterator begin() const noexcept { + return const_iterator(moves_.begin()); + } + const_iterator cbegin() const noexcept { + return const_iterator(moves_.begin()); + } + + iterator end() noexcept { return iterator(moves_.end()); } + const_iterator end() const noexcept { return const_iterator(moves_.end()); } + const_iterator cend() const noexcept { return const_iterator(moves_.end()); } + + void push_back(TypedMove move) { moves_.push_back(move.get()); } + + template + Reference emplace_back(Args&&... args) { + moves_.emplace_back(std::forward(args)...); + return MakeTypedMove(moves_.back()); + } + + void pop_back() { moves_.pop_back(); } + + [[nodiscard]] Container release() noexcept { return std::move(moves_); } + + [[nodiscard]] const Container& data() const noexcept { return moves_; } + + private: + Container moves_; +}; static_assert(sizeof(Move) == 2, "Move must be exactly 2 bytes"); static_assert(alignof(Move) == 2, "Move should be 2-byte aligned"); -static_assert(std::is_trivially_copyable_v, "Move must be trivially copyable"); -static_assert(std::is_standard_layout_v, "Move must have standard layout"); +static_assert(std::is_trivially_copyable_v, + "Move must be trivially copyable"); +static_assert(std::is_standard_layout_v, + "Move must have standard layout"); + +static_assert(sizeof(PseudoLegalMove) == sizeof(Move), + "PseudoLegalMove must have same size as Move"); +static_assert(sizeof(LegalMove) == sizeof(Move), + "LegalMove must have same size as Move"); +static_assert(sizeof(PseudoLegalMoveRef) == sizeof(Move*), + "PseudoLegalMoveRef must be pointer-sized"); +static_assert(sizeof(LegalMoveRef) == sizeof(Move*), + "LegalMoveRef must be pointer-sized"); +static_assert(std::is_convertible_v, + "LegalMove must be convertible to PseudoLegalMove"); +static_assert(std::is_convertible_v, + "LegalMove must be convertible to Move"); +static_assert(std::is_convertible_v, + "PseudoLegalMove must be convertible to Move"); + +static_assert(!std::is_convertible_v, + "Move must not be convertible to PseudoLegalMove"); +static_assert(!std::is_convertible_v, + "PseudoLegalMove must not be convertible to LegalMove"); +static_assert(!std::is_convertible_v, + "Move must not be convertible to LegalMove"); + +static_assert(std::is_convertible_v, + "LegalMoveRef must be convertible to PseudoLegalMoveRef"); +static_assert(!std::is_convertible_v, + "PseudoLegalMoveRef must not be convertible to LegalMoveRef"); + +class Position; + +template +To UnsafeMoveCast(From move) { + static_assert( + std::is_same_v || std::is_same_v, + "To must be PseudoLegalMove or LegalMove"); + static_assert( + std::is_same_v || std::is_same_v, + "From must be Move or PseudoLegalMove"); + + if constexpr (std::is_same_v && + std::is_same_v) { + return MakeTypedMove(move); + } else if constexpr (std::is_same_v && + std::is_same_v) { + return MakeTypedMove(static_cast(move)); + } else if constexpr (std::is_same_v && + std::is_same_v) { + return MakeTypedMove(move); + } else { + return move; + } +} } // namespace SimpleChessEngine + +namespace std { +template + requires SimpleChessEngine::MoveTag +void swap( + SimpleChessEngine::TypedMove lhs, + SimpleChessEngine::TypedMove rhs) noexcept { + std::swap(lhs.get(), rhs.get()); +} +} // namespace std diff --git a/Chess/MoveGenerator.cpp b/Chess/MoveGenerator.cpp index 23fc5e5..01bdacc 100644 --- a/Chess/MoveGenerator.cpp +++ b/Chess/MoveGenerator.cpp @@ -6,352 +6,382 @@ #include "BitBoard.h" namespace SimpleChessEngine { -template MoveGenerator::Moves MoveGenerator::GenerateMoves< - MoveGenerator::Type::kAll>(Position& position) const; -template MoveGenerator::Moves MoveGenerator::GenerateMoves< - MoveGenerator::Type::kQuiescence>(Position& position) const; -template MoveGenerator::Moves MoveGenerator::GenerateMoves< - MoveGenerator::Type::kAddChecks>(Position& position) const; - -constexpr MoveGenerator::Type operator|(MoveGenerator::Type a, - MoveGenerator::Type b) { - return static_cast(static_cast(a) | - static_cast(b)); -} - -constexpr MoveGenerator::Type operator&(MoveGenerator::Type a, - MoveGenerator::Type b) { - return static_cast(static_cast(a) & - static_cast(b)); -} -constexpr bool operator!(MoveGenerator::Type t) { - return static_cast(t) == 0; -} +template MoveList MoveGenerator::GenerateMoves< + MoveGenerator::Type::kCaptures>(const Position& position) const; +template MoveList MoveGenerator::GenerateMoves< + MoveGenerator::Type::kQuiets>(const Position& position) const; +template MoveList MoveGenerator::GenerateMoves< + MoveGenerator::Type::kEvasions>(const Position& position) const; +template MoveList MoveGenerator::GenerateMoves< + MoveGenerator::Type::kNonEvasions>(const Position& position) const; +template MoveList MoveGenerator::GenerateMoves< + MoveGenerator::Type::kLegal>(const Position& position) const; MoveGenerator::~MoveGenerator() = default; -[[nodiscard]] bool MoveGenerator::IsPawnMoveLegal(Position& position, - const Move& move) { - const auto us = position.GetSideToMove(); - - if (move.IsEnPassant()) { - const auto irreversible_data = position.GetIrreversibleData(); - position.DoMove(move); - const auto valid = !position.IsUnderCheck(us); - position.UndoMove(move, irreversible_data); - return valid; - } - - const auto from = move.From(); - const auto to = move.To(); - - return !position.GetIrreversibleData().blockers[static_cast(us)].Test( - from) || - Ray(position.GetKingSquare(us), from).Test(to); -} - -void MoveGenerator::GenerateCastling(Moves& moves, const Position& position) { - if (position.IsUnderCheck()) { - return; - } +template +void MoveGenerator::GeneratePawnMoves(Moves& moves, const Position& position, + Bitboard target) const { + constexpr Player Them = Flip(Us); + constexpr Bitboard promotion_rank = + (Us == Player::kWhite ? kRankBB[6] : kRankBB[1]); + constexpr Bitboard third_rank = + (Us == Player::kWhite ? kRankBB[2] : kRankBB[5]); + constexpr Compass up = kPawnMoveDirection[static_cast(Us)]; + constexpr Compass down = kPawnMoveDirection[static_cast(Them)]; + + const auto pawns = position.GetPiecesByType(Us); + const auto empty_squares = ~position.GetAllPieces(); + const auto enemies = GenType == Type::kEvasions + ? (position.Attackers(position.GetKingSquare(Us)) & + position.GetPieces(Them)) + : position.GetPieces(Them); + + const auto pawns_on_7 = pawns & promotion_rank; + const auto pawns_not_on_7 = pawns & ~promotion_rank; + + if constexpr (GenType != Type::kCaptures) { + auto push = Shift(pawns_not_on_7, up) & empty_squares; + auto double_push = Shift(push & third_rank, up) & empty_squares; + + if constexpr (GenType == Type::kEvasions) { + push &= target; + double_push &= target; + } - const auto side_to_move = position.GetSideToMove(); - const auto king_square = position.GetKingSquare(side_to_move); - const auto color_idx = static_cast(side_to_move); + while (push.Any()) { + const auto to = push.PopFirstBit(); + const auto from = Shift(to, down); + moves.emplace_back(Move(from, to)); + } - for (const auto castling_side : - {CastlingSide::k00, CastlingSide::k000}) { - if (position.CanCastle(castling_side)) { - const auto side_idx = static_cast(castling_side); - const auto king_to = kKingCastlingDestination[color_idx][side_idx]; - - moves.emplace_back(Move::Make(king_square, king_to)); + while (double_push.Any()) { + const auto to = double_push.PopFirstBit(); + const auto from = Shift(Shift(to, down), down); + moves.emplace_back(Move(from, to)); } } -} -template -void MoveGenerator::GenerateMovesForPiece(Moves& moves, Position& position, - const Bitboard target) const { - static_assert(piece != Piece::kPawn && piece != Piece::kKing); + if (pawns_on_7.Any()) { + constexpr std::array attack_dirs = { + (Us == Player::kWhite ? Compass::kNorthWest : Compass::kSouthWest), + (Us == Player::kWhite ? Compass::kNorthEast : Compass::kSouthEast)}; + constexpr std::array opposite_dirs = { + (Us == Player::kWhite ? Compass::kSouthEast : Compass::kNorthEast), + (Us == Player::kWhite ? Compass::kSouthWest : Compass::kNorthWest)}; + + auto promo_push = Shift(pawns_on_7, up) & empty_squares; + if constexpr (GenType == Type::kEvasions) { + promo_push &= target; + } - const auto us = position.GetSideToMove(); - Bitboard pieces = position.GetPiecesByType(us); + while (promo_push.Any()) { + const auto to = promo_push.PopFirstBit(); + const auto from = Shift(to, down); + if constexpr (GenType == Type::kCaptures || GenType == Type::kEvasions || + GenType == Type::kNonEvasions) { + moves.emplace_back( + Move::Make(from, to, Piece::kQueen)); + } + if constexpr (GenType == Type::kQuiets || GenType == Type::kEvasions || + GenType == Type::kNonEvasions) { + moves.emplace_back( + Move::Make(from, to, Piece::kRook)); + moves.emplace_back( + Move::Make(from, to, Piece::kBishop)); + moves.emplace_back( + Move::Make(from, to, Piece::kKnight)); + } + } - while (pieces.Any()) { - const auto from = pieces.PopFirstBit(); - GenerateMovesFromSquare(moves, position, from, target); + for (size_t i = 0; i < 2; ++i) { + auto promo_capture = + Shift(pawns_on_7 & ~(i == 0 ? kFileBB[0] : kFileBB[7]), + attack_dirs[i]) & + enemies; + while (promo_capture.Any()) { + const auto to = promo_capture.PopFirstBit(); + const auto from = Shift(to, opposite_dirs[i]); + if constexpr (GenType == Type::kCaptures || + GenType == Type::kEvasions || + GenType == Type::kNonEvasions) { + moves.emplace_back( + Move::Make(from, to, Piece::kQueen)); + } + if constexpr (GenType == Type::kQuiets || GenType == Type::kEvasions || + GenType == Type::kNonEvasions) { + moves.emplace_back( + Move::Make(from, to, Piece::kRook)); + moves.emplace_back( + Move::Make(from, to, Piece::kBishop)); + moves.emplace_back( + Move::Make(from, to, Piece::kKnight)); + } + } + } } -} - -template <> -void MoveGenerator::GenerateMovesForPiece( - Moves& moves, Position& position, const Bitboard target) const { - const auto us = position.GetSideToMove(); - const auto us_idx = static_cast(us); - const auto them = Flip(us); - const auto them_idx = static_cast(them); - - const auto pawns = position.GetPiecesByType(us); - const auto promotion_rank = us == Player::kWhite ? kRankBB[6] : kRankBB[1]; - const auto direction = kPawnMoveDirection[us_idx]; - const auto opposite_direction = kPawnMoveDirection[them_idx]; - const auto non_promoting_pawns = pawns & ~promotion_rank; - - const auto valid_squares = ~position.GetAllPieces(); - - const auto third_rank = us == Player::kWhite ? kRankBB[2] : kRankBB[5]; - - auto push = Shift(non_promoting_pawns, direction) & valid_squares; - - const auto double_push_pawns = push & third_rank; - - push &= target; - while (push.Any()) { - const auto to = push.PopFirstBit(); - const auto from = Shift(to, opposite_direction); - moves.emplace_back(Move(from, to)); - } + if constexpr (GenType == Type::kCaptures || GenType == Type::kEvasions || + GenType == Type::kNonEvasions) { + constexpr std::array attack_dirs = { + (Us == Player::kWhite ? Compass::kNorthWest : Compass::kSouthWest), + (Us == Player::kWhite ? Compass::kNorthEast : Compass::kSouthEast)}; + constexpr std::array opposite_dirs = { + (Us == Player::kWhite ? Compass::kSouthEast : Compass::kNorthEast), + (Us == Player::kWhite ? Compass::kSouthWest : Compass::kNorthWest)}; + + for (size_t i = 0; i < 2; ++i) { + auto captures = + Shift(pawns_not_on_7 & ~(i == 0 ? kFileBB[0] : kFileBB[7]), + attack_dirs[i]) & + enemies; + while (captures.Any()) { + const auto to = captures.PopFirstBit(); + const auto from = Shift(to, opposite_dirs[i]); + moves.emplace_back(Move(from, to)); + } + } - auto double_push = Shift(double_push_pawns, direction) & valid_squares; + if (const auto ep_square = position.GetEnCroissantSquare(); + ep_square.has_value()) { + if constexpr (GenType == Type::kEvasions) { + if ((target & SingleSquare(Shift(ep_square.value(), down))).None()) { + return; + } + } - double_push &= target; - while (double_push.Any()) { - const auto to = double_push.PopFirstBit(); - const auto from = Shift(Shift(to, opposite_direction), opposite_direction); - moves.emplace_back(Move(from, to)); + const auto ep_bb = SingleSquare(ep_square.value()); + for (size_t i = 0; i < 2; ++i) { + auto ep_capture = + Shift(pawns_not_on_7 & ~(i == 0 ? kFileBB[0] : kFileBB[7]), + attack_dirs[i]) & + ep_bb; + if (ep_capture.Any()) { + const auto to = ep_square.value(); + const auto from = Shift(to, opposite_dirs[i]); + moves.emplace_back(Move::Make(from, to)); + } + } + } } +} - static constexpr std::array cant_attack_files = {kFileBB[0], kFileBB[7]}; - - const auto attacks = - (us == Player::kWhite) - ? std::array{Compass::kNorthWest, Compass::kNorthEast} - : std::array{Compass::kSouthWest, Compass::kSouthEast}; - - const auto opposite_attacks = - (us == Player::kWhite) - ? std::array{Compass::kSouthEast, Compass::kSouthWest} - : std::array{Compass::kNorthEast, Compass::kNorthWest}; - - const auto enemy_pieces = position.GetPieces(them); - - const auto en_croissant_square = position.GetEnCroissantSquare(); +template +void MoveGenerator::GeneratePieceMoves(Moves& moves, const Position& position, + Bitboard target) const { + static_assert(Pt != Piece::kKing && Pt != Piece::kPawn); - const std::array attacks_to = { - Shift(non_promoting_pawns & ~cant_attack_files.front(), attacks.front()), - Shift(non_promoting_pawns & ~cant_attack_files.back(), attacks.back())}; + auto pieces = position.GetPiecesByType(Us); + while (pieces.Any()) { + const auto from = pieces.PopFirstBit(); + auto attacks = + AttackTable::GetAttackMap(from, position.GetAllPieces()) & target; - for (size_t attack_direction = 0; attack_direction < attacks.size(); - ++attack_direction) { - auto attack_squares = attacks_to[attack_direction] & target & enemy_pieces; + if (position.GetIrreversibleData().blockers[static_cast(Us)].Test( + from)) { + attacks &= Ray(position.GetKingSquare(Us), from); + } - while (attack_squares.Any()) { - const auto to = attack_squares.PopFirstBit(); - const auto from = Shift(to, opposite_attacks[attack_direction]); + while (attacks.Any()) { + const auto to = attacks.PopFirstBit(); moves.emplace_back(Move(from, to)); } } +} - if (en_croissant_square) { - const auto en_croissant_bitboard = - SingleSquare(en_croissant_square.value()); - for (size_t attack_direction = 0; attack_direction < attacks.size(); - ++attack_direction) { - auto attack_to = attacks_to[attack_direction] & en_croissant_bitboard; - if (attack_to.Any()) { - const auto to = en_croissant_square.value(); - const auto from = Shift(to, opposite_attacks[attack_direction]); - moves.emplace_back(Move::Make(from, to)); - } - } +void MoveGenerator::GenerateCastling(Moves& moves, const Position& position, + Player us) const { + if (position.IsUnderCheck(us)) { + return; } - const auto promoting_pawns = pawns & promotion_rank; - - auto promotion_push = - Shift(promoting_pawns, direction) & valid_squares & target; - - while (promotion_push.Any()) { - const auto to = promotion_push.PopFirstBit(); - const auto from = Shift(to, opposite_direction); - - moves.emplace_back(Move::Make(from, to, Piece::kQueen)); - moves.emplace_back(Move::Make(from, to, Piece::kKnight)); - moves.emplace_back(Move::Make(from, to, Piece::kRook)); - moves.emplace_back(Move::Make(from, to, Piece::kBishop)); - } + const auto king_square = position.GetKingSquare(us); + const auto color_idx = static_cast(us); - for (size_t attack_direction = 0; attack_direction < attacks.size(); - ++attack_direction) { - auto attack_squares = - Shift(promoting_pawns & ~cant_attack_files[attack_direction], - attacks[attack_direction]) & - target & enemy_pieces; - - while (attack_squares.Any()) { - const auto to = attack_squares.PopFirstBit(); - const auto from = Shift(to, opposite_attacks[attack_direction]); - - moves.emplace_back(Move::Make(from, to, Piece::kKnight)); - moves.emplace_back(Move::Make(from, to, Piece::kBishop)); - moves.emplace_back(Move::Make(from, to, Piece::kRook)); - moves.emplace_back(Move::Make(from, to, Piece::kQueen)); + for (const auto castling_side : {CastlingSide::k00, CastlingSide::k000}) { + if (position.CanCastle(castling_side)) { + const auto side_idx = static_cast(castling_side); + const auto king_to = kKingCastlingDestination[color_idx][side_idx]; + moves.emplace_back(Move::Make(king_square, king_to)); } } } -template -MoveGenerator::Moves MoveGenerator::GenerateMoves(Position& position) const { - moves_.clear(); +template +void MoveGenerator::GenerateAll(Moves& moves, const Position& position) const { + const auto king_square = position.GetKingSquare(Us); + const auto them = Flip(Us); - const auto us = position.GetSideToMove(); - const auto them = Flip(us); - - auto target = ~position.GetPieces(us); - - if constexpr (!!(type & Type::kQuiescence)) { - target &= position.GetPieces(Flip(us)); - } + Bitboard target; - const auto king_square = position.GetKingSquare(us); - const auto king_attacker = + const auto checkers = position.Attackers(king_square) & position.GetPieces(them); - // Double-check check - if (king_attacker.MoreThanOne()) { - GenerateMovesForPiece(moves_, position, target); - return moves_; - } - - // compute pins - position.ComputePins(us); - - const auto king_target = target; - auto pawn_target = target; - - if constexpr (!!(type & Type::kQuiescence)) { - pawn_target |= (kRankBB[0] | kRankBB[7]); - } - if constexpr (!!(type & Type::kAddChecks)) { - pawn_target |= GetPawnAttacks(king_square, us); - } - // is in check - if (king_attacker.Any()) { - const auto attacker = king_attacker.GetFirstBit(); - const auto ray = - Between(king_square, attacker) | SingleSquare(attacker); - target &= ray; - pawn_target &= ray; + if constexpr (GenType != Type::kEvasions) { + if (checkers.Any()) { + return; + } } - GenerateMovesForPiece(moves_, position, - pawn_target & ~position.GetPieces(us)); + if (checkers.MoreThanOne()) { + target = GenType == Type::kEvasions ? ~position.GetPieces(Us) : target; + auto king_attacks = AttackTable::GetAttackMap( + king_square, position.GetAllPieces()) & + target; - std::erase_if(moves_, [&position](const Move& move) { - return !IsPawnMoveLegal(position, move); - }); + const auto occupancy = position.GetAllPieces() ^ SingleSquare(king_square); + king_attacks &= ~position.GetAllPawnAttacks(them); - // generate moves for piece - auto generate_move_for_piece = [this, &position, king_square, - us](Bitboard target) { - if constexpr (!!(type & Type::kAddChecks)) { - target |= - AttackTable::GetAttackMap(king_square, position.GetPieces(us)); + auto attackers = position.GetPiecesByType(them); + while (attackers.Any()) { + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); } - GenerateMovesForPiece(moves_, position, - target & ~position.GetPieces(us)); - }; - auto generate_moves = [target, &generate_move_for_piece]() { - (generate_move_for_piece.template operator()(target), ...); - }; + attackers = position.GetPiecesByType(them); + while (attackers.Any()) { + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); + } - generate_moves.template - operator()(); + attackers = position.GetPiecesByType(them); + while (attackers.Any()) { + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); + } - GenerateMovesForPiece(moves_, position, - king_target & ~position.GetPieces(us)); - GenerateCastling(moves_, position); + attackers = position.GetPiecesByType(them); + while (attackers.Any()) { + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); + } - // return moves - return moves_; -} + king_attacks &= ~AttackTable::GetAttackMap( + position.GetKingSquare(them), occupancy); -template <> -void MoveGenerator::GenerateMovesForPiece(Moves& moves, - Position& position, - Bitboard target) const { - const auto us = position.GetSideToMove(); - const auto them = Flip(us); + while (king_attacks.Any()) { + const auto to = king_attacks.PopFirstBit(); + moves.emplace_back(Move(king_square, to)); + } + return; + } - const auto king_pos = position.GetKingSquare(us); - const auto king_mask = SingleSquare(king_pos); + if constexpr (GenType == Type::kEvasions) { + target = Between(king_square, checkers.GetFirstBit()) | checkers; + } else if constexpr (GenType == Type::kNonEvasions) { + target = ~position.GetPieces(Us); + } else if constexpr (GenType == Type::kCaptures) { + target = position.GetPieces(them); + } else { + target = ~position.GetAllPieces(); + } - const auto occupancy = position.GetAllPieces() ^ king_mask; + GeneratePawnMoves(moves, position, target); + GeneratePieceMoves(moves, position, target); + GeneratePieceMoves(moves, position, target); + GeneratePieceMoves(moves, position, target); + GeneratePieceMoves(moves, position, target); - // we prevent the king from going to squares attacked by enemy pieces + auto king_target = + GenType == Type::kEvasions ? ~position.GetPieces(Us) : target; + auto king_attacks = AttackTable::GetAttackMap( + king_square, position.GetAllPieces()) & + king_target; - target &= ~position.GetAllPawnAttacks(Flip(us)); + const auto occupancy = position.GetAllPieces() ^ SingleSquare(king_square); + king_attacks &= ~position.GetAllPawnAttacks(them); - Bitboard attackers = position.GetPiecesByType(them); + auto attackers = position.GetPiecesByType(them); while (attackers.Any()) { - target &= ~AttackTable::GetAttackMap( + king_attacks &= ~AttackTable::GetAttackMap( attackers.PopFirstBit(), occupancy); } attackers = position.GetPiecesByType(them); while (attackers.Any()) { - target &= ~AttackTable::GetAttackMap( + king_attacks &= ~AttackTable::GetAttackMap( attackers.PopFirstBit(), occupancy); } attackers = position.GetPiecesByType(them); while (attackers.Any()) { - target &= ~AttackTable::GetAttackMap(attackers.PopFirstBit(), - occupancy); + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); } attackers = position.GetPiecesByType(them); while (attackers.Any()) { - target &= ~AttackTable::GetAttackMap(attackers.PopFirstBit(), - occupancy); + king_attacks &= ~AttackTable::GetAttackMap( + attackers.PopFirstBit(), occupancy); } - target &= ~AttackTable::GetAttackMap( + king_attacks &= ~AttackTable::GetAttackMap( position.GetKingSquare(them), occupancy); - GenerateMovesFromSquare(moves, position, king_pos, target); + while (king_attacks.Any()) { + const auto to = king_attacks.PopFirstBit(); + moves.emplace_back(Move(king_square, to)); + } + + if constexpr (GenType == Type::kQuiets || GenType == Type::kNonEvasions) { + GenerateCastling(moves, position, Us); + } } -template -void MoveGenerator::GenerateMovesFromSquare(Moves& moves, Position& position, - const BitIndex from, - Bitboard target) const { - assert(position.GetPieceAt(from) == piece); +template +auto MoveGenerator::GenerateMoves(const Position& position) const + -> std::conditional_t, + MoveList> { + moves_.clear(); - // get all squares that piece attacks - const auto attacks = - AttackTable::GetAttackMap(from, position.GetAllPieces()); + const auto us = position.GetSideToMove(); + position.ComputePins(us); // TODO: OPTIMIZE, dont need to call it always - // get whose move is now - const auto side_to_move = position.GetSideToMove(); + if constexpr (type == Type::kLegal) { + const auto pinned = + position.GetIrreversibleData().blockers[static_cast(us)] & + position.GetPieces(us); + const auto king_square = position.GetKingSquare(us); - // if the piece is pinned we can only move in pin direction - if (position.GetIrreversibleData() - .blockers[static_cast(side_to_move)] - .Test(from)) { - target &= Ray(position.GetKingSquare(side_to_move), from); - } + const size_t initial_size = moves_.size(); - // we move only in target squares - auto valid_moves = attacks & target; + if (position.IsUnderCheck(us)) { + if (us == Player::kWhite) { + GenerateAll(moves_, position); + } else { + GenerateAll(moves_, position); + } + } else { + if (us == Player::kWhite) { + GenerateAll(moves_, position); + } else { + GenerateAll(moves_, position); + } + } - while (valid_moves.Any()) { - const auto to = valid_moves.PopFirstBit(); - moves.emplace_back(Move(from, to)); + auto it = moves_.begin() + initial_size; + while (it != moves_.end()) { + if (((pinned.Test(it->From())) || it->From() == king_square || + it->IsEnPassant()) && + !position.Legal(*it)) { + *it = moves_.back(); + moves_.pop_back(); + } else { + ++it; + } + } + + return MoveList(std::move(moves_)); + } else { + if (us == Player::kWhite) { + GenerateAll(moves_, position); + } else { + GenerateAll(moves_, position); + } + + return MoveList(std::move(moves_)); } } + } // namespace SimpleChessEngine diff --git a/Chess/MoveGenerator.h b/Chess/MoveGenerator.h index 33c6a2e..68a0d08 100644 --- a/Chess/MoveGenerator.h +++ b/Chess/MoveGenerator.h @@ -14,59 +14,39 @@ namespace SimpleChessEngine { */ class MoveGenerator { public: - enum class Type : std::uint8_t { kAll = 0, kQuiescence = 1, kAddChecks = 2 }; + enum class Type : std::uint8_t { + kCaptures, + kQuiets, + kEvasions, + kNonEvasions, + kLegal + }; static constexpr size_t kMaxMovesPerPosition = 218; MoveGenerator() { moves_.reserve(kMaxMovesPerPosition); } ~MoveGenerator(); - using Moves = std::vector; - - /** - * \brief Generates all possible moves for a given position. - * - * \param position The position. - * - * \return All possible moves for the given position. - */ template - [[nodiscard]] Moves GenerateMoves(Position& position) const; + [[nodiscard]] auto GenerateMoves(const Position& position) const + -> std::conditional_t, + MoveList>; private: - [[nodiscard]] static bool IsPawnMoveLegal(Position& position, - const Move& move); + using Moves = std::vector; + template + void GenerateAll(Moves& moves, const Position& position) const; - /** - * \brief Generates all possible moves for a given square. - * - * \param position The position. - * \param moves Container where to add possible moves. - * \param target Target squares. - * - * \return All possible moves for the given square. - */ - template - void GenerateMovesForPiece(Moves& moves, Position& position, - Bitboard target) const; + template + void GeneratePawnMoves(Moves& moves, const Position& position, + Bitboard target) const; - /** - * \brief Generates all possible moves for a given square with given piece. - * - * \tparam piece The piece. - * \param moves Container where to add moves. - * \param position The position. - * \param from The square. - * \param target Target squares. - * - * \return All possible moves for the given square and piece. - */ - template - void GenerateMovesFromSquare(Moves& moves, Position& position, BitIndex from, - Bitboard target) const; + template + void GeneratePieceMoves(Moves& moves, const Position& position, + Bitboard target) const; - static void GenerateCastling(Moves& moves, const Position& position); + void GenerateCastling(Moves& moves, const Position& position, + Player us) const; - // mutable because it is used as a pre-allocated storage to return, not as an object state mutable Moves moves_; }; diff --git a/Chess/MovePicker.cpp b/Chess/MovePicker.cpp index cb3865c..442d5b8 100644 --- a/Chess/MovePicker.cpp +++ b/Chess/MovePicker.cpp @@ -47,7 +47,7 @@ MovePicker::Stage& operator++(MovePicker::Stage& stage) { return stage; } -MoveGenerator::Moves::const_iterator MovePicker::SelectNextMove( +MovePicker::Moves::const_iterator MovePicker::SelectNextMove( const Searcher& searcher, const Depth ply) { const auto& position = searcher.GetPosition(); const auto compare_captures = [this, &position](size_t lhs, size_t rhs) { @@ -128,11 +128,10 @@ MoveGenerator::Moves::const_iterator MovePicker::SelectNextMove( } assert(false); - std::unreachable(); + __builtin_unreachable(); } -void MovePicker::InitPicker(MoveGenerator::Moves&& moves, - const Searcher& searcher) { +void MovePicker::InitPicker(Moves&& moves, const Searcher& searcher) { moves_ = std::move(moves); data_.resize(moves_.size()); history_.resize(moves_.size()); @@ -140,7 +139,7 @@ void MovePicker::InitPicker(MoveGenerator::Moves&& moves, const auto& move = moves_[i]; const auto from = move.From(); const auto to = move.To(); - + // Determine captured piece Piece capture = Piece::kNone; if (move.IsEnPassant()) { @@ -148,7 +147,7 @@ void MovePicker::InitPicker(MoveGenerator::Moves&& moves, } else if (!move.IsCastling()) { capture = searcher.GetPosition().GetPieceAt(to); } - + data_[i] = {from, to, capture, capture != Piece::kNone ? searcher.GetPosition().StaticExchangeEvaluation( @@ -160,10 +159,14 @@ void MovePicker::InitPicker(MoveGenerator::Moves&& moves, current_move_ = moves_.begin(); } -void MovePicker::SkipMove(const Move& move) { - Swap(current_move_ - moves_.begin(), - std::find(current_move_, moves_.end(), move) - moves_.begin()); +bool MovePicker::SkipMove(Move move) { + const auto it = std::find(current_move_, moves_.end(), move); + if (it == moves_.end()) { + return false; + } + Swap(current_move_ - moves_.begin(), it - moves_.begin()); ++current_move_; + return true; } bool MovePicker::Done() const { return current_move_ == moves_.end(); } diff --git a/Chess/MovePicker.h b/Chess/MovePicker.h index 3c5e41c..fa16301 100644 --- a/Chess/MovePicker.h +++ b/Chess/MovePicker.h @@ -4,7 +4,7 @@ #include #include -#include "Evaluation.h" +#include "Eval.h" #include "Move.h" #include "Piece.h" #include "Utility.h" @@ -14,7 +14,7 @@ class Searcher; class MovePicker { public: - using Moves = std::vector; + using Moves = MoveList; enum class Stage : std::uint8_t { kGoodCaptures, kKillers, @@ -24,8 +24,6 @@ class MovePicker { }; MovePicker(); - MovePicker(const Move&) = delete; - MovePicker(Move&&) = delete; void InitPicker(Moves&& moves, const Searcher& searcher); ~MovePicker() = default; @@ -35,14 +33,16 @@ class MovePicker { Moves::const_iterator SelectNextMove(const Searcher& searcher, const Depth ply); - void SkipMove(const Move& move); + [[nodiscard]] bool SkipMove(Move move); [[nodiscard]] bool Done() const; [[nodiscard]] Stage GetCurrentStage() const; [[nodiscard]] Moves::const_iterator begin() const { return moves_.begin(); } - [[nodiscard]] Moves::const_iterator begin_quiet() const { return begin_quiet_; } + [[nodiscard]] Moves::const_iterator begin_quiet() const { + return begin_quiet_; + } [[nodiscard]] Moves::const_iterator current() const { return current_move_; } [[nodiscard]] Moves::const_iterator end() const { return moves_.end(); } diff --git a/Chess/Perft.cpp b/Chess/Perft.cpp index 5c95aa5..ae615d8 100644 --- a/Chess/Perft.cpp +++ b/Chess/Perft.cpp @@ -13,7 +13,7 @@ size_t Perft(std::ostream& o_stream, Position& position, const Depth depth) { static auto move_generator = MoveGenerator{}; const auto moves = - move_generator.GenerateMoves(position); + move_generator.GenerateMoves(position); size_t answer{}; @@ -50,16 +50,16 @@ size_t Perft(std::ostream& o_stream, Position& position, const Depth depth) { PerftResult PerftBench(Position& position, Depth depth) { const auto start_time = std::chrono::high_resolution_clock::now(); - + std::ostringstream dummy_stream; const auto nodes = Perft(dummy_stream, position, depth); - + const auto time = std::chrono::duration( std::chrono::high_resolution_clock::now() - start_time) .count(); - + const auto nps = static_cast(nodes / time); - + return PerftResult{nodes, nps}; } diff --git a/Chess/Position.cpp b/Chess/Position.cpp index ec0144f..c32b010 100644 --- a/Chess/Position.cpp +++ b/Chess/Position.cpp @@ -1,9 +1,11 @@ #include "Position.h" +#include #include #include "Attacks.h" #include "BitBoard.h" +#include "MoveGenerator.h" #include "PSQT.h" #include "Piece.h" @@ -90,7 +92,7 @@ void Position::MovePiece(const BitIndex from, const BitIndex to, hash_ ^= hasher_.psqt_hash[piece_idx][color_idx][to]; } -void Position::DoMove(const Move &move) { +void Position::DoMove(LegalMove move) { if (const auto &ep_square = irreversible_data_.en_croissant_square; ep_square.has_value()) { hash_ ^= hasher_.en_croissant_hash[GetCoordinates(ep_square.value()).first]; @@ -106,13 +108,15 @@ void Position::DoMove(const Move &move) { const auto to = move.To(); const auto us = side_to_move_; const auto them = Flip(us); - + // Store captured piece for undo irreversible_data_.captured_piece = board_[to]; if (move.IsEnPassant()) { - const auto capture_square = Shift(to, kPawnMoveDirection[static_cast(them)]); - irreversible_data_.captured_piece = Piece::kPawn; // En passant always captures a pawn + const auto capture_square = + Shift(to, kPawnMoveDirection[static_cast(them)]); + irreversible_data_.captured_piece = + Piece::kPawn; // En passant always captures a pawn RemovePiece(capture_square, them); MovePiece(from, to, us); } else if (move.IsPromotion()) { @@ -121,25 +125,29 @@ void Position::DoMove(const Move &move) { RemovePiece(from, us); if (!!captured_piece) RemovePiece(to, them); PlacePiece(to, promoted_to, us); - + for (const auto castling_side : {CastlingSide::k00, CastlingSide::k000}) { - const auto their_rook = rook_positions_[static_cast(them)][static_cast(castling_side)]; + const auto their_rook = + rook_positions_[static_cast(them)] + [static_cast(castling_side)]; if (to == their_rook) { irreversible_data_.castling_rights[static_cast(them)] &= - ~static_cast(kCastlingRightsForSide[static_cast(castling_side)]); + ~static_cast( + kCastlingRightsForSide[static_cast(castling_side)]); } } } else if (move.IsCastling()) { - irreversible_data_.captured_piece = Piece::kNone; // Castling never captures + irreversible_data_.captured_piece = + Piece::kNone; // Castling never captures // Determine castling side and rook position from move const auto king_to = to; const auto king_from = from; - + // Determine which side based on king destination const auto color_idx = static_cast(us); CastlingSide side; BitIndex rook_from; - + if (king_to == kKingCastlingDestination[color_idx][0]) { side = CastlingSide::k00; rook_from = rook_positions_[color_idx][0]; @@ -147,45 +155,50 @@ void Position::DoMove(const Move &move) { side = CastlingSide::k000; rook_from = rook_positions_[color_idx][1]; } - + const auto side_idx = static_cast(side); RemovePiece(king_from, us); RemovePiece(rook_from, us); PlacePiece(kKingCastlingDestination[color_idx][side_idx], Piece::kKing, us); PlacePiece(kRookCastlingDestination[color_idx][side_idx], Piece::kRook, us); - + king_position_[color_idx] = kKingCastlingDestination[color_idx][side_idx]; irreversible_data_.castling_rights[color_idx] = 0; } else { // Normal move (including pawn pushes and double pushes) const auto piece_to_move = board_[from]; const auto captured_piece = board_[to]; - + // Check if it's a double pawn push if (piece_to_move == Piece::kPawn && std::abs(from - to) == 16) { const auto file = GetCoordinates(from).first; hash_ ^= hasher_.en_croissant_hash[file]; irreversible_data_.en_croissant_square = std::midpoint(from, to); } - + if (!!captured_piece) RemovePiece(to, them); MovePiece(from, to, us); - + if (piece_to_move == Piece::kKing) { king_position_[static_cast(us)] = to; irreversible_data_.castling_rights[static_cast(us)] = 0; } - + for (auto castling_side : {CastlingSide::k00, CastlingSide::k000}) { - const auto our_rook = rook_positions_[static_cast(us)][static_cast(castling_side)]; - const auto their_rook = rook_positions_[static_cast(them)][static_cast(castling_side)]; + const auto our_rook = rook_positions_[static_cast(us)] + [static_cast(castling_side)]; + const auto their_rook = + rook_positions_[static_cast(them)] + [static_cast(castling_side)]; if (from == our_rook) { irreversible_data_.castling_rights[static_cast(us)] &= - ~static_cast(kCastlingRightsForSide[static_cast(castling_side)]); + ~static_cast( + kCastlingRightsForSide[static_cast(castling_side)]); } if (to == their_rook) { irreversible_data_.castling_rights[static_cast(them)] &= - ~static_cast(kCastlingRightsForSide[static_cast(castling_side)]); + ~static_cast( + kCastlingRightsForSide[static_cast(castling_side)]); } } } @@ -217,8 +230,7 @@ void SimpleChessEngine::Position::DoMove(NullMove) { history_stack_.Push(hash_, false); } - -void Position::UndoMove(const Move &move, const IrreversibleData &data) { +void Position::UndoMove(LegalMove move, const IrreversibleData &data) { const auto &ep_square = irreversible_data_.en_croissant_square; for (const auto color : {Player::kWhite, Player::kBlack}) { hash_ ^= hasher_.cr_hash[static_cast( @@ -228,14 +240,14 @@ void Position::UndoMove(const Move &move, const IrreversibleData &data) { if (ep_square.has_value()) { hash_ ^= hasher_.en_croissant_hash[GetCoordinates(ep_square.value()).first]; } - + const auto from = move.From(); const auto to = move.To(); const auto them = side_to_move_; // Current side (after the move was made) const auto us = Flip(them); // Side that made the move - + const auto captured_piece = irreversible_data_.captured_piece; - + irreversible_data_ = data; for (const auto color : {Player::kWhite, Player::kBlack}) { hash_ ^= hasher_.cr_hash[static_cast( @@ -243,13 +255,16 @@ void Position::UndoMove(const Move &move, const IrreversibleData &data) { .to_ulong()]; } if (irreversible_data_.en_croissant_square.has_value()) { - hash_ ^= hasher_.en_croissant_hash[GetCoordinates(irreversible_data_.en_croissant_square.value()).first]; + hash_ ^= hasher_.en_croissant_hash + [GetCoordinates(irreversible_data_.en_croissant_square.value()) + .first]; } hash_ ^= hasher_.stm_hash; side_to_move_ = Flip(side_to_move_); if (move.IsEnPassant()) { - const auto capture_square = Shift(to, kPawnMoveDirection[static_cast(them)]); + const auto capture_square = + Shift(to, kPawnMoveDirection[static_cast(them)]); MovePiece(to, from, us); PlacePiece(capture_square, Piece::kPawn, them); } else if (move.IsPromotion()) { @@ -258,31 +273,31 @@ void Position::UndoMove(const Move &move, const IrreversibleData &data) { PlacePiece(from, Piece::kPawn, us); } else if (move.IsCastling()) { const auto color_idx = static_cast(us); - + CastlingSide side; if (to == kKingCastlingDestination[color_idx][0]) { side = CastlingSide::k00; } else { side = CastlingSide::k000; } - + const auto side_idx = static_cast(side); const auto rook_from = rook_positions_[color_idx][side_idx]; - + RemovePiece(kKingCastlingDestination[color_idx][side_idx], us); RemovePiece(kRookCastlingDestination[color_idx][side_idx], us); - + PlacePiece(from, Piece::kKing, us); PlacePiece(rook_from, Piece::kRook, us); - + king_position_[color_idx] = from; } else { // Normal move const auto piece_to_move = board_[to]; - + MovePiece(to, from, us); if (!!captured_piece) PlacePiece(to, captured_piece, them); - + if (piece_to_move == Piece::kKing) { king_position_[static_cast(us)] = from; } @@ -304,9 +319,7 @@ void SimpleChessEngine::Position::UndoMove(NullMove, history_stack_.Pop(); } - -[[nodiscard]] bool Position::CanCastle( - const CastlingSide castling_side) const { +[[nodiscard]] bool Position::CanCastle(const CastlingSide castling_side) const { const auto us = side_to_move_; const auto us_idx = static_cast(us); const auto cs_idx = static_cast(castling_side); @@ -432,7 +445,7 @@ bool Position::StaticExchangeEvaluation(const Move &move, const auto from = move.From(); const auto to = move.To(); - + Piece captured_piece = Piece::kNone; if (move.IsEnPassant()) { captured_piece = Piece::kPawn; @@ -441,7 +454,7 @@ bool Position::StaticExchangeEvaluation(const Move &move, } Piece next_victim = GetPieceAt(from); - + if (move.IsPromotion()) { next_victim = move.PromotionPiece(); } @@ -449,7 +462,8 @@ bool Position::StaticExchangeEvaluation(const Move &move, Eval balance = EstimatePiece(captured_piece); if (move.IsPromotion()) { - balance += EstimatePiece(move.PromotionPiece()) - EstimatePiece(Piece::kPawn); + balance += + EstimatePiece(move.PromotionPiece()) - EstimatePiece(Piece::kPawn); } else if (move.IsEnPassant()) { balance = EstimatePiece(Piece::kPawn); } @@ -469,7 +483,8 @@ bool Position::StaticExchangeEvaluation(const Move &move, Bitboard occupancy = GetAllPieces() ^ SingleSquare(from) ^ SingleSquare(to); [[unlikely]] if (move.IsEnPassant()) { - occupancy ^= Shift(SingleSquare(to), kPawnMoveDirection[static_cast(them)]); + occupancy ^= + Shift(SingleSquare(to), kPawnMoveDirection[static_cast(them)]); } Bitboard attackers = Attackers(to, ~occupancy) & occupancy; @@ -560,7 +575,7 @@ Bitboard Position::Attackers(const BitIndex square, pieces_by_type_[static_cast(Piece::kKing)]); } -void Position::ComputePins(const Player us) { +void Position::ComputePins(const Player us) const { const Player them = Flip(us); const auto us_idx = static_cast(us); @@ -634,4 +649,151 @@ size_t Position::GameHistory::Count(const Hash hash, Depth depth) const { } return result; } + +bool Position::PseudoLegal(const Move &move) const { + const auto us = side_to_move_; + const auto from = move.From(); + const auto to = move.To(); + const auto piece = GetPieceAt(from); + + if (!move.IsNormal()) { + MoveGenerator generator; + const auto moves = + IsUnderCheck(us) + ? generator.GenerateMoves(*this) + : generator.GenerateMoves(*this); + return std::find(moves.begin(), moves.end(), move) != moves.end(); + } + + if (piece == Piece::kNone || GetPieces(us).Test(from) == false) { + return false; + } + + if (GetPieces(us).Test(to)) { + return false; + } + + if (piece == Piece::kPawn) { + if ((kRankBB[7] | kRankBB[0]).Test(to)) { + return false; + } + + const auto is_capture = + (GetPawnAttacks(from, us) & GetPieces(Flip(us))).Test(to); + const auto is_single_push = + (Shift(from, kPawnMoveDirection[static_cast(us)]) == to) && + GetPieceAt(to) == Piece::kNone; + const auto is_double_push = + (Shift(from, kPawnMoveDirection[static_cast(us)]) + + static_cast( + kPawnMoveDirection[static_cast(us)]) == + to) && + (GetCoordinates(from).second == (us == Player::kWhite ? 1 : 6)) && + GetPieceAt(to) == Piece::kNone && + GetPieceAt( + Shift(to, kPawnMoveDirection[static_cast(Flip(us))])) == + Piece::kNone; + + if (!(is_capture || is_single_push || is_double_push)) { + return false; + } + } else if (!(AttackTable::GetAttackMap(from, GetAllPieces()) + .Test(to) && + piece == Piece::kKnight) && + !(AttackTable::GetAttackMap(from, GetAllPieces()) + .Test(to) && + piece == Piece::kBishop) && + !(AttackTable::GetAttackMap(from, GetAllPieces()) + .Test(to) && + piece == Piece::kRook) && + !(AttackTable::GetAttackMap(from, GetAllPieces()) + .Test(to) && + piece == Piece::kQueen) && + !(AttackTable::GetAttackMap(from, GetAllPieces()) + .Test(to) && + piece == Piece::kKing)) { + return false; + } + + if (IsUnderCheck()) { + if (piece != Piece::kKing) { + const auto checkers = Attackers(GetKingSquare(us)) & GetPieces(Flip(us)); + if (checkers.MoreThanOne()) { + return false; + } + + if (to == checkers.GetFirstBit()) { + return true; + } + + if (!Between(GetKingSquare(us), checkers.GetFirstBit()).Test(to)) { + return false; + } + } else if (IsUnderAttack(to, us, SingleSquare(from))) { + return false; + } + } + + return true; +} + +bool Position::Legal(const Move &move) const { + assert(move.IsValid()); + + const auto us = side_to_move_; + const auto from = move.From(); + const auto to = move.To(); + + assert(GetPieces(us).Test(from)); + assert(GetPieceAt(GetKingSquare(us)) == Piece::kKing); + + if (move.IsEnPassant()) { + const auto king_square = GetKingSquare(us); + const auto capture_square = + Shift(to, kPawnMoveDirection[static_cast(Flip(us))]); + const auto occupied = + (GetAllPieces() ^ SingleSquare(from) ^ SingleSquare(capture_square)) | + SingleSquare(to); + + assert(to == GetEnCroissantSquare().value()); + assert(GetPieceAt(from) == Piece::kPawn); + assert(GetPieceAt(capture_square) == Piece::kPawn); + assert(GetPieceAt(to) == Piece::kNone); + + return !(AttackTable::GetAttackMap(king_square, occupied) & + (GetPiecesByType(Flip(us)) | + GetPiecesByType(Flip(us)))) + .Any() && + !(AttackTable::GetAttackMap(king_square, occupied) & + (GetPiecesByType(Flip(us)) | + GetPiecesByType(Flip(us)))) + .Any(); + } + + if (move.IsCastling()) { + const auto color_idx = static_cast(us); + const auto final_king_pos = to > from + ? kKingCastlingDestination[color_idx][0] + : kKingCastlingDestination[color_idx][1]; + const auto step = to > from ? Compass::kWest : Compass::kEast; + + for (auto square = final_king_pos; square != from; + square = Shift(square, step)) { + if (IsUnderAttack(square, us)) { + return false; + } + } + + return true; + } + + if (GetPieceAt(from) == Piece::kKing) { + return !IsUnderAttack(to, us, SingleSquare(from)); + } + + return !(irreversible_data_.blockers[static_cast(us)].Test(from)) || + ((Ray(from, to) | Ray(to, from)) & GetPiecesByType(us)) + .Any(); +} + } // namespace SimpleChessEngine \ No newline at end of file diff --git a/Chess/Position.h b/Chess/Position.h index c6ca8b4..071611a 100644 --- a/Chess/Position.h +++ b/Chess/Position.h @@ -56,16 +56,17 @@ class Position { std::array, kColors> castling_rights{}; //!< Castling rights for each color. - std::array + mutable std::array pinners{}; //!< Pieces that are pinning opponent's pieces. - std::array + mutable std::array blockers{}; //!< Pieces that are blocking attacks on the king. - + Piece captured_piece{Piece::kNone}; //!< Piece captured by the last move bool operator==(const IrreversibleData &other) const { return std::tie(en_croissant_square, castling_rights, captured_piece) == - std::tie(other.en_croissant_square, other.castling_rights, other.captured_piece); + std::tie(other.en_croissant_square, other.castling_rights, + other.captured_piece); } }; @@ -112,20 +113,19 @@ class Position { * * \param move Move to do. */ - void DoMove(const Move &move); + void DoMove(LegalMove move); /** * \brief Undoes given move. * * \param move Move to undo. */ - void UndoMove(const Move &move, const IrreversibleData &data); + void UndoMove(LegalMove move, const IrreversibleData &data); void DoMove(NullMove); void UndoMove(NullMove, const IrreversibleData &data); - [[nodiscard]] bool CanCastle( - const CastlingSide castling_side) const; + [[nodiscard]] bool CanCastle(const CastlingSide castling_side) const; /** * \brief Gets hash of the position. @@ -180,13 +180,13 @@ class Position { [[nodiscard]] BitIndex GetKingSquare(Player player) const; - [[nodiscard]] BitIndex GetCastlingRookSquare( - Player player, CastlingSide side) const; + [[nodiscard]] BitIndex GetCastlingRookSquare(Player player, + CastlingSide side) const; [[nodiscard]] Bitboard Attackers(BitIndex square, Bitboard transparent = kEmptyBoard) const; - void ComputePins(Player us); + void ComputePins(Player us) const; [[nodiscard]] bool IsUnderAttack(BitIndex square, Player us, Bitboard transparent = kEmptyBoard) const; @@ -215,6 +215,10 @@ class Position { [[nodiscard]] bool StaticExchangeEvaluation(const Move &move, Eval threshold) const; + [[nodiscard]] bool PseudoLegal(const Move &move) const; + + [[nodiscard]] bool Legal(const Move &move) const; + [[nodiscard]] const std::optional &GetEnCroissantSquare() const; [[nodiscard]] const std::array, kColors> &GetCastlingRights() @@ -271,7 +275,6 @@ class Position { void MovePiece(const BitIndex from, const BitIndex to, const Player color); - void SetCastlingRights(const std::array, 2> &castling_rights); void SetKingPositions(const std::array &king_position); @@ -309,4 +312,52 @@ class Position { Hash hash_{}; }; +template +std::optional MoveCast(From move, const Position &position); + +template <> +inline std::optional MoveCast( + Move move, const Position &position) { + if (!move || !position.PseudoLegal(move)) { + return std::nullopt; + } + return UnsafeMoveCast(move); +} + +template <> +inline std::optional MoveCast( + PseudoLegalMove move, const Position &position) { + if (!position.Legal(move)) { + return std::nullopt; + } + return UnsafeMoveCast(move); +} + +template <> +inline std::optional MoveCast( + Move move, const Position &position) { + if (!position.PseudoLegal(move) || !position.Legal(move)) { + return std::nullopt; + } + return UnsafeMoveCast(move); +} + +template <> +inline std::optional MoveCast( + PseudoLegalMoveRef move, const Position &position) { + return MoveCast(static_cast(move), position); +} + +template <> +inline std::optional MoveCast( + PseudoLegalMoveConstRef move, const Position &position) { + return MoveCast(static_cast(move.get()), position); +} + +template +To UnsafeMoveCast(From move, [[maybe_unused]] const Position &position) { + assert(MoveCast(move, position).has_value()); + return UnsafeMoveCast(move); +} + } // namespace SimpleChessEngine \ No newline at end of file diff --git a/Chess/Quiescence.cpp b/Chess/Quiescence.cpp index fc81687..89882ec 100644 --- a/Chess/Quiescence.cpp +++ b/Chess/Quiescence.cpp @@ -3,6 +3,8 @@ #include #include "ExitCondition.h" +#include "Move.h" +#include "Position.h" namespace SimpleChessEngine { template class Quiescence; @@ -54,7 +56,7 @@ SearchResult Quiescence::Search(Position& current_position, } // get all the attacks moves - auto moves = move_generator_.GenerateMoves( + auto moves = move_generator_.GenerateMoves( current_position); /* @@ -66,14 +68,17 @@ SearchResult Quiescence::Search(Position& current_position, for (const auto& move : moves) { if (!current_position.StaticExchangeEvaluation( - move, std::max(1, alpha - stand_pat - kSEEMargin))) { + move, std::max(1, alpha - stand_pat - kSEEMargin))) { continue; } + const auto legal_move = MoveCast(move, current_position); + if (!legal_move) continue; + const auto irreversible_data = current_position.GetIrreversibleData(); // make the move and search the tree - current_position.DoMove(move); + current_position.DoMove(*legal_move); const auto temp_eval_optional = Search(current_position, -beta, -alpha, current_depth + 1); @@ -82,7 +87,7 @@ SearchResult Quiescence::Search(Position& current_position, const auto temp_eval = -*temp_eval_optional; // undo the move - current_position.UndoMove(move, irreversible_data); + current_position.UndoMove(*legal_move, irreversible_data); if (temp_eval > alpha) { if (temp_eval >= beta) { @@ -101,12 +106,13 @@ template SearchResult Quiescence::SearchUnderCheck( Position& current_position, Eval alpha, Eval beta, const Depth current_depth) { - MoveGenerator::Moves moves = - move_generator_.GenerateMoves( - current_position); + auto moves = move_generator_.GenerateMoves( + current_position); + + bool found_legal_move = false; if (moves.empty()) { - return kMateValue + kMaxSearchPly; + return kMateValue + Eval(kMaxSearchPly); } /* @@ -119,8 +125,12 @@ SearchResult Quiescence::SearchUnderCheck( for (const auto& move : moves) { const auto irreversible_data = current_position.GetIrreversibleData(); - // make the move and search the tree - current_position.DoMove(move); + const auto legal_move = MoveCast(move, current_position); + if (!legal_move) { + continue; + } + found_legal_move = true; + current_position.DoMove(*legal_move); const auto temp_eval_optional = Search(current_position, -beta, -alpha, current_depth + 1); @@ -129,7 +139,7 @@ SearchResult Quiescence::SearchUnderCheck( const auto temp_eval = -*temp_eval_optional; // undo the move - current_position.UndoMove(move, irreversible_data); + current_position.UndoMove(*legal_move, irreversible_data); if (temp_eval > alpha) { if (temp_eval >= beta) { @@ -140,6 +150,10 @@ SearchResult Quiescence::SearchUnderCheck( } } + if (!found_legal_move) { + return kMateValue + Eval(kMaxSearchPly); + } + return alpha; } @@ -149,4 +163,4 @@ bool Quiescence::IsTimeToExit() { return searched_nodes_ % kEnoughNodesToCheckTime == 0 && exit_condition_.IsTimeToExit(); } -} // namespace SimpleChessEngine \ No newline at end of file +} // namespace SimpleChessEngine diff --git a/Chess/SearchImplementation.h b/Chess/SearchImplementation.h index a634ef2..0e051e9 100644 --- a/Chess/SearchImplementation.h +++ b/Chess/SearchImplementation.h @@ -80,7 +80,7 @@ struct SearchNode { SearchResult QuiescenceSearch(); Eval GetEndGameScore() const; - void SetBestMove(Move move); + void SetBestMove(LegalMove move); void SetTTEntry(const Bound bound); template void UpdateQuietMove(const Move &move); @@ -88,9 +88,9 @@ struct SearchNode { Position &GetCurrentPosition(); template - SearchResult ProbeMove(const Move &move); + SearchResult ProbeMove(LegalMove move); template - std::optional CheckFirstMove(const Move &move); + std::optional CheckFirstMove(LegalMove move); [[nodiscard]] bool CanRFP() const; @@ -176,7 +176,7 @@ SearchResult SearchNode::operator()() { if (ProbeTranspositionTable()) { searcher_.debug_info_.tt_hits++; - auto [hash, hash_move, entry_score, entry_depth, entry_bound, _] = + auto [hash_move, entry_score, entry_depth, entry_bound, is_pv] = *iteration_status_.tt_info; entry_score -= IsMateScore(entry_score) * (max_depth - remaining_depth); @@ -206,6 +206,11 @@ SearchResult SearchNode::operator()() { return entry_score; } } + } else if constexpr (kIsPrincipalVariation) { + searcher_.debug_info_.tt_pv_misses += remaining_depth > 1; + } + else { + searcher_.debug_info_.tt_other_misses += remaining_depth > 1; } if (CanRFP()) { @@ -225,7 +230,7 @@ SearchResult SearchNode::operator()() { static_cast( remaining_depth - Settings::PruneParameters::NMPSettings::kNullMoveReduction), - -beta, -beta + 1, true}); + -beta, -beta + Eval(1), true}); current_position.UndoMove(NullMove{}, position_info_.irreversible_data); @@ -258,7 +263,8 @@ SearchResult SearchNode::operator()() { auto const &move_generator = searcher_.move_generator_; move_picker_.InitPicker( - move_generator.GenerateMoves(current_position), + move_generator.GenerateMoves( + current_position), searcher_); // check if there are no possible moves @@ -266,7 +272,11 @@ SearchResult SearchNode::operator()() { return GetEndGameScore(); } - if (!iteration_status_.best_move) { + if (!iteration_status_.best_move || + !move_picker_.SkipMove(*iteration_status_.best_move)) { + if (iteration_status_.best_move) { + searcher_.debug_info_.tt_wrong_moves++; + } auto has_cutoff_opt = CheckFirstMove( *move_picker_.SelectNextMove(searcher_, max_depth - remaining_depth)); if (!has_cutoff_opt) { @@ -276,10 +286,6 @@ SearchResult SearchNode::operator()() { SetTTEntry(Bound::kLower); return beta; } - } else { - // skip the first move - assert(iteration_status_.best_move); - move_picker_.SkipMove(*iteration_status_.best_move); } return PVSearch(); @@ -306,8 +312,7 @@ template requires StopSearchCondition Eval SearchNode::GetEndGameScore() const { if (position_info_.is_under_check) { - return kMateValue + - static_cast(state_.max_depth - state_.remaining_depth); + return kMateValue + Eval(state_.max_depth - state_.remaining_depth); } return kDrawValue; @@ -315,7 +320,7 @@ Eval SearchNode::GetEndGameScore() const { template requires StopSearchCondition -void SearchNode::SetBestMove(Move move) { +void SearchNode::SetBestMove(LegalMove move) { if (state_.remaining_depth == state_.max_depth) { searcher_.best_move_ = move; } @@ -326,18 +331,21 @@ template requires StopSearchCondition void SearchNode::SetTTEntry(const Bound bound) { assert(iteration_status_.best_move); - searcher_.best_moves_.SetEntry( - GetCurrentPosition(), *iteration_status_.best_move, - iteration_status_.best_eval + - IsMateScore(iteration_status_.best_eval) * - (state_.max_depth - state_.remaining_depth), - state_.remaining_depth, bound, searcher_.age_); + auto result = searcher_.best_moves_.Probe(GetCurrentPosition().GetHash()); + result.entry.get().Save(GetCurrentPosition().GetHash(), + iteration_status_.best_eval + + IsMateScore(iteration_status_.best_eval) * + (state_.max_depth - state_.remaining_depth), + kIsPrincipalVariation, bound, state_.remaining_depth, + *iteration_status_.best_move, + iteration_status_.best_eval, + searcher_.best_moves_.GetGeneration()); } template requires StopSearchCondition template -SearchResult SearchNode::ProbeMove(const Move &move) { +SearchResult SearchNode::ProbeMove(LegalMove move) { auto ¤t_position = GetCurrentPosition(); auto &[max_depth, remaining_depth, alpha, beta, _] = state_; @@ -361,7 +369,7 @@ template requires StopSearchCondition template std::optional SearchNode::CheckFirstMove( - const Move &move) { + LegalMove move) { static_assert(expected_node_type == kFirstChildNodeExpectedType); static_assert(expected_node_type != NodeType::kPV || expected_node_type == node_type); @@ -431,8 +439,8 @@ SearchResult SearchNode::PVSearch() { } auto temp_eval_optional = StartSubsearch( - {max_depth, static_cast(remaining_depth - 1 - R), -alpha - 1, - -alpha}); // Reduced ZWS + {max_depth, static_cast(remaining_depth - 1 - R), + -alpha - Eval(1), -alpha}); if (!temp_eval_optional) { current_position.UndoMove(move, position_info_.irreversible_data); @@ -444,7 +452,7 @@ SearchResult SearchNode::PVSearch() { temp_eval > alpha) { /* research at full depth, but still with zero window */ temp_eval_optional = StartSubsearch( - {max_depth, static_cast(remaining_depth - 1), -alpha - 1, + {max_depth, static_cast(remaining_depth - 1), -alpha - Eval(1), -alpha}); if (!temp_eval_optional) { @@ -583,9 +591,9 @@ bool SearchNode::CanRFP() const { template requires StopSearchCondition bool SearchNode::ProbeTranspositionTable() { - if (auto node = searcher_.best_moves_.GetNode(searcher_.current_position_); - node.true_hash == GetCurrentPosition().GetHash()) { - iteration_status_.tt_info = std::move(node); + auto result = searcher_.best_moves_.Probe(GetCurrentPosition().GetHash()); + if (result.found) { + iteration_status_.tt_info = std::move(result.data); return true; } @@ -598,11 +606,16 @@ std::optional SearchNode::CheckTranspositionTable() { if (iteration_status_.tt_info) { auto &[max_depth, remaining_depth, alpha, beta, _] = state_; - auto [hash, hash_move, entry_score, entry_depth, entry_bound, _] = + auto [hash_move, entry_score, entry_depth, entry_bound, is_pv] = *iteration_status_.tt_info; + GetCurrentPosition().ComputePins(GetCurrentPosition().GetSideToMove()); + auto legal_move = MoveCast(hash_move, GetCurrentPosition()); + if (!legal_move) { + return std::nullopt; + } auto has_cutoff_opt = - CheckFirstMove(hash_move); + CheckFirstMove(*legal_move); if (!has_cutoff_opt) { return SearchResult{std::nullopt}; } diff --git a/Chess/Searcher.h b/Chess/Searcher.h index 34cc9d0..35e8522 100644 --- a/Chess/Searcher.h +++ b/Chess/Searcher.h @@ -36,8 +36,7 @@ class Searcher { template requires StopSearchCondition friend struct SearchNode; - constexpr static size_t kTTSizeInMb = 640; - using SearcherTranspositionTable = TranspositionTable; + using SearcherTranspositionTable = TranspositionTable; /** * \brief Constructor. @@ -66,7 +65,7 @@ class Searcher { * * \return The current best move. */ - [[nodiscard]] const Move &GetCurrentBestMove() const; + [[nodiscard]] LegalMove GetCurrentBestMove() const; /** * \brief Performs the alpha-beta search algorithm. @@ -89,12 +88,13 @@ class Searcher { [[nodiscard]] const auto &GetKillers() const { return killers_; } [[nodiscard]] const auto &GetHistory() const { return history_; } - [[nodiscard]] MoveGenerator::Moves GetPrincipalVariation( - Depth max_depth, Position position) const; + [[nodiscard]] MoveList GetPrincipalVariation(Depth max_depth, + Position position); + [[nodiscard]] int HashFull() const { return best_moves_.HashFull(); } private: Age age_{}; - Move best_move_{}; + LegalMove best_move_{}; Position current_position_; //!< Current position. MoveGenerator move_generator_; //!< Move generator. SearcherTranspositionTable @@ -120,18 +120,28 @@ inline void Searcher::SetPosition(Position position) { current_position_ = std::move(position); } -inline const Position &Searcher::GetPosition() const { return current_position_; } +inline const Position &Searcher::GetPosition() const { + return current_position_; +} -inline const Move &Searcher::GetCurrentBestMove() const { return best_move_; } +inline LegalMove Searcher::GetCurrentBestMove() const { return best_move_; } -inline MoveGenerator::Moves Searcher::GetPrincipalVariation(Depth max_depth, - Position position) const { - MoveGenerator::Moves answer; +inline MoveList Searcher::GetPrincipalVariation(Depth max_depth, + Position position) { + MoveList answer; for (Depth i = 0; i < max_depth; ++i) { - const auto &hashed_node = best_moves_.GetNode(position); - if (hashed_node.true_hash != position.GetHash()) break; - position.DoMove(hashed_node.move); - answer.push_back(hashed_node.move); + const auto result = best_moves_.Probe(position.GetHash()); + if (!result.found) break; + + const auto pseudo_legal = + MoveCast(result.data.move, position); + if (!pseudo_legal) break; + + const auto legal = MoveCast(*pseudo_legal, position); + if (!legal) break; + + position.DoMove(*legal); + answer.push_back(*legal); } return answer; } @@ -143,6 +153,7 @@ inline void Searcher::InitStartOfSearch() { history_[color][from].fill(0LL); } } + best_moves_.NewSearch(); } template diff --git a/Chess/SimpleChessEngine.h b/Chess/SimpleChessEngine.h index 0ed13e2..06824f3 100644 --- a/Chess/SimpleChessEngine.h +++ b/Chess/SimpleChessEngine.h @@ -33,11 +33,11 @@ struct NodePerSecondInfo { struct PrincipalVariationInfo { size_t current_depth = 0; - MoveGenerator::Moves moves; + MoveList moves; }; struct BestMoveInfo { - const Move& move; + LegalMove move; std::optional ponder; }; @@ -92,7 +92,7 @@ class ChessEngine { void ComputeBestMove(SearchCondition auto& conditions); - [[nodiscard]] const Move& GetCurrentBestMove() const; + [[nodiscard]] LegalMove GetCurrentBestMove() const; void PrintBestMove() { o_stream_ << BestMoveInfo{GetCurrentBestMove(), std::nullopt}; @@ -144,6 +144,9 @@ class ChessEngine { o_stream_ << "info quiescence_nodes " << info.quiescence_nodes << "\n"; o_stream_ << "info zws_researches " << info.zws_researches << "\n"; o_stream_ << "info tt_hits " << info.tt_hits << "\n"; + o_stream_ << "info tt_pv_misses " << info.tt_pv_misses << "\n"; + o_stream_ << "info tt_other_misses " << info.tt_other_misses << "\n"; + o_stream_ << "info tt_wrong_moves " << info.tt_wrong_moves << "\n"; o_stream_ << "info tt_cuts " << info.tt_cuts << "\n"; o_stream_ << "info rfp_cuts " << info.rfp_cuts << "\n"; o_stream_ << "info nmp_tries " << info.nmp_tries << "\n"; @@ -201,8 +204,8 @@ class ChessEngine { Searcher searcher_; Position position_; - Move best_move_; - std::optional ponder_move_; + LegalMove best_move_; + std::optional ponder_move_; }; } // namespace SimpleChessEngine @@ -250,6 +253,7 @@ inline void SimpleChessEngine::ChessEngine::ComputeBestMove( std::chrono::system_clock::now() - start_time) .count() << "\n"; + o_stream_ << "info hashfull " << searcher_.HashFull() << "\n"; if (auto two_move_pv = searcher_.GetPrincipalVariation(2, position_); two_move_pv.size() > 1) { @@ -266,7 +270,7 @@ inline void SimpleChessEngine::ChessEngine::ComputeBestMove( PrintBestMove(BestMoveInfo{best_move_, ponder_move_}); } -inline const Move& ChessEngine::GetCurrentBestMove() const { +inline LegalMove ChessEngine::GetCurrentBestMove() const { return searcher_.GetCurrentBestMove(); } diff --git a/Chess/StreamUtility.h b/Chess/StreamUtility.h index af150e5..97577d4 100644 --- a/Chess/StreamUtility.h +++ b/Chess/StreamUtility.h @@ -49,7 +49,7 @@ inline std::ostream& PrintCoordinates(const Coordinates coordinates, return stream; } -inline std::ostream& operator<<(std::ostream& stream, const Move& move) { +inline std::ostream& operator<<(std::ostream& stream, Move move) { const auto from = GetCoordinates(move.From()); const auto to = GetCoordinates(move.To()); diff --git a/Chess/TranspositionTable.h b/Chess/TranspositionTable.h index 1b5739f..1100870 100644 --- a/Chess/TranspositionTable.h +++ b/Chess/TranspositionTable.h @@ -1,13 +1,24 @@ #pragma once #include -#include +#include +#include +#include #include +#include "Eval.h" #include "Hasher.h" #include "Move.h" -#include "Position.h" +#include "Utility.h" + namespace SimpleChessEngine { + +class Position; + +using ShortHash = std::uint16_t; +using Generation = std::uint8_t; + enum class Bound : std::uint8_t { + kNone = 0, kLower = 1, kUpper = 2, kExact = kLower | kUpper @@ -17,49 +28,168 @@ inline std::uint8_t operator&(const Bound lhs, const Bound rhs) { return static_cast(lhs) & static_cast(rhs); } -#pragma pack(push, 1) struct Node { - Hash true_hash{}; Move move{}; Eval score{}; - Depth depth : 6 {}; + Depth depth{}; + Bound bound{Bound::kNone}; + bool is_pv_node{false}; + + Node() = default; + + Node(Move m, Eval s, Depth d, Bound b, bool pv) + : move(m), score(s), depth(d), bound(b), is_pv_node(pv) {} +}; + +constexpr std::uint8_t kGenerationSizeInBits = 5; + +#pragma pack(push, 1) +struct TableEntry { + ShortHash short_hash{0}; + Depth depth_stored{0}; + Generation generation : kGenerationSizeInBits{}; Bound bound : 2 {}; - Age age{}; + bool is_pv : 1 {}; + Move move{}; + Eval score{0}; + Eval static_eval{0}; + + [[nodiscard]] bool IsOccupied() const { return depth_stored != 0; } + [[nodiscard]] Node Read() const; + void Save(Hash key, Eval score, bool is_pv_node, Bound bound, Depth depth, + Move move, Eval static_eval, Generation generation); + [[nodiscard]] std::uint8_t RelativeAge(Generation current_generation) const; }; #pragma pack(pop) -template +struct EntryCluster { + static constexpr size_t kClusterSize = 3; + std::array entries{}; + std::array padding{}; +}; + +static_assert(sizeof(TableEntry) == 10, "TableEntry must be 10 bytes"); +static_assert(sizeof(EntryCluster) == 32, + "EntryCluster must be 32 bytes for optimal cache performance"); + +struct ProbeResult { + bool found; + Node data; + std::reference_wrapper entry; +}; + class TranspositionTable { public: - static constexpr size_t kTableSize = 1 << 25; + static constexpr size_t kClusterSize = EntryCluster::kClusterSize; + static constexpr size_t kBytesPerCluster = sizeof(EntryCluster); + static constexpr size_t kDefaultSizeMB = 640; + static constexpr size_t kClusterCount = + (kDefaultSizeMB * 1024 * 1024) / kBytesPerCluster; - [[nodiscard]] bool Contains(const Position& position) const { - return position.GetHash() == GetNode(position).true_hash; - } + TranspositionTable() { Clear(); } + + void Clear(); + void NewSearch() { generation_ += kGenerationDelta; } + [[nodiscard]] Generation GetGeneration() const { return generation_; } + [[nodiscard]] int HashFull(int max_age = 0) const; + [[nodiscard]] ProbeResult Probe(Hash key); + + private: + [[nodiscard]] std::array::iterator GetFirstEntry(Hash key); + + std::vector table_; + Generation generation_{0}; - void SetEntry(const Position& position, const Move& move, const Eval score, - const Depth depth, const Bound bound, const Age age) { - Node inserting_node = {position.GetHash(), move, score, depth, bound, age}; - auto& entry_node = GetNode(position); - if (bound == Bound::kExact || - !(entry_node.bound == Bound::kExact && entry_node.age == age)) { - entry_node = inserting_node; + static constexpr Generation kGenerationDelta = 1; + static constexpr std::uint8_t kGenerationMask = 0x1F; +}; + +inline void TranspositionTable::Clear() { + table_.clear(); + table_.resize(kClusterCount); + generation_ = 0; +} + +inline int TranspositionTable::HashFull(int max_age) const { + int count = 0; + + const size_t sample_size = std::min(1000, kClusterCount); + for (size_t i = 0; i < sample_size; ++i) { + for (size_t j = 0; j < kClusterSize; ++j) { + const auto& entry = table_[i].entries[j]; + if (entry.IsOccupied() && entry.RelativeAge(generation_) <= max_age) { + ++count; + } } } - const Move& GetMove(const Position& position) const { - assert(Contains(position)); - return GetNode(position).move; + return count / kClusterSize; +} + +inline ProbeResult TranspositionTable::Probe(Hash key) { + auto first_entry = GetFirstEntry(key); + const ShortHash short_hash = static_cast(key); + for (size_t i = 0; i < kClusterSize; ++i) { + if (first_entry[i].short_hash == short_hash) { + return {first_entry[i].IsOccupied(), first_entry[i].Read(), + std::ref(first_entry[i])}; + } } - Node& GetNode(const Position& position) { - return table_[position.GetHash() % kTableSize]; + auto replace = first_entry; + for (size_t i = 1; i < kClusterSize; ++i) { + const auto replace_value = + replace->depth_stored - replace->RelativeAge(generation_); + const auto candidate_value = + first_entry[i].depth_stored - first_entry[i].RelativeAge(generation_); + + if (replace_value > candidate_value) { + replace = &first_entry[i]; + } } - const Node& GetNode(const Position& position) const { - return table_[position.GetHash() % kTableSize]; + return {false, Node{Move::None(), 0, 0, Bound::kNone, false}, + std::ref(*replace)}; +} + +inline std::array::iterator TranspositionTable::GetFirstEntry( + Hash key) { + const size_t cluster_index = static_cast( + (static_cast<__uint128_t>(key) * kClusterCount) >> 64); + return table_[cluster_index].entries.begin(); +} + +inline Node TableEntry::Read() const { + return Node{move, score, static_cast(depth_stored), bound, is_pv}; +} + +inline void TableEntry::Save(Hash key, Eval score_value, bool is_pv_node, + Bound bound_value, Depth depth_value, + Move move_value, Eval static_eval_value, + Generation gen) { + const bool should_replace = !IsOccupied() || bound_value == Bound::kExact || + (!(this->bound == Bound::kExact) && + depth_value + 4 > this->depth_stored) || + RelativeAge(gen) > 0; + + if (should_replace) { + this->short_hash = static_cast(key); + this->depth_stored = static_cast(depth_value); + this->generation = gen; + this->bound = bound_value; + this->is_pv = is_pv_node; + this->score = score_value; + this->static_eval = static_eval_value; + this->move = move_value; } +} + +inline std::uint8_t TableEntry::RelativeAge( + Generation current_generation) const { + constexpr std::uint8_t kMaxGeneration = (1 << kGenerationSizeInBits) - 1; + constexpr std::uint8_t kGenerationRange = 1 << kGenerationSizeInBits; + current_generation &= kMaxGeneration; + return (kGenerationRange + current_generation - generation) & kMaxGeneration; +} - std::vector table_ = std::vector(kTableSize); //!< The table. -}; } // namespace SimpleChessEngine diff --git a/Chess/UciCommunicator.h b/Chess/UciCommunicator.h index d5b6ea9..f7268b8 100644 --- a/Chess/UciCommunicator.h +++ b/Chess/UciCommunicator.h @@ -9,6 +9,7 @@ #include #include +#include "Move.h" #include "MoveFactory.h" #include "Perft.h" #include "Position.h" @@ -329,7 +330,16 @@ inline void UciChessEngine::ParseMoves(std::stringstream command) { std::string move; while (!command.eof()) { command >> move; - info_.position.DoMove(MoveFactory{}(info_.position, move)); + const auto parsed_move = MoveFactory{}(info_.position, move); + if (auto legal_move = MoveCast(parsed_move, info_.position)) { + info_.position.DoMove(*legal_move); + } else { + Send("Illegal move: " + move + " PseudoLegal: " + + std::to_string(info_.position.PseudoLegal(parsed_move)) + + " Legal: " + std::to_string(info_.position.Legal(parsed_move))); + info_.position.DoMove(UnsafeMoveCast(parsed_move)); + } + info_.position.ComputePins(info_.position.GetSideToMove()); } } diff --git a/Tests/CompactMoveTests.cpp b/Tests/CompactMoveTests.cpp index 0fef689..911c0d4 100644 --- a/Tests/CompactMoveTests.cpp +++ b/Tests/CompactMoveTests.cpp @@ -7,109 +7,103 @@ using namespace SimpleChessEngine; namespace CompactMoveTests { TEST(BasicConstruction, FromToSquares) { - const Move move(0, 8); - ASSERT_EQ(move.From(), 0); - ASSERT_EQ(move.To(), 8); - ASSERT_TRUE(move.IsNormal()); - ASSERT_TRUE(move.IsValid()); + const Move move(0, 8); + ASSERT_EQ(move.From(), 0); + ASSERT_EQ(move.To(), 8); + ASSERT_TRUE(move.IsNormal()); + ASSERT_TRUE(move.IsValid()); } TEST(PromotionMoves, AllPieces) { - const Move queen_promo = Move::Make(48, 56, Piece::kQueen); - ASSERT_TRUE(queen_promo.IsPromotion()); - ASSERT_EQ(queen_promo.PromotionPiece(), Piece::kQueen); - - const Move knight_promo = Move::Make(48, 56, Piece::kKnight); - ASSERT_EQ(knight_promo.PromotionPiece(), Piece::kKnight); - - const Move rook_promo = Move::Make(48, 56, Piece::kRook); - ASSERT_EQ(rook_promo.PromotionPiece(), Piece::kRook); - - const Move bishop_promo = Move::Make(48, 56, Piece::kBishop); - ASSERT_EQ(bishop_promo.PromotionPiece(), Piece::kBishop); + const Move queen_promo = + Move::Make(48, 56, Piece::kQueen); + ASSERT_TRUE(queen_promo.IsPromotion()); + ASSERT_EQ(queen_promo.PromotionPiece(), Piece::kQueen); + + const Move knight_promo = + Move::Make(48, 56, Piece::kKnight); + ASSERT_EQ(knight_promo.PromotionPiece(), Piece::kKnight); + + const Move rook_promo = + Move::Make(48, 56, Piece::kRook); + ASSERT_EQ(rook_promo.PromotionPiece(), Piece::kRook); + + const Move bishop_promo = + Move::Make(48, 56, Piece::kBishop); + ASSERT_EQ(bishop_promo.PromotionPiece(), Piece::kBishop); } TEST(SpecialMoves, EnPassantAndCastling) { - const Move en_passant = Move::Make(32, 41); - ASSERT_TRUE(en_passant.IsEnPassant()); - ASSERT_FALSE(en_passant.IsPromotion()); - ASSERT_FALSE(en_passant.IsCastling()); - ASSERT_FALSE(en_passant.IsNormal()); - - const Move castling = Move::Make(4, 6); - ASSERT_TRUE(castling.IsCastling()); - ASSERT_FALSE(castling.IsEnPassant()); - ASSERT_FALSE(castling.IsPromotion()); - ASSERT_FALSE(castling.IsNormal()); + const Move en_passant = Move::Make(32, 41); + ASSERT_TRUE(en_passant.IsEnPassant()); + ASSERT_FALSE(en_passant.IsPromotion()); + ASSERT_FALSE(en_passant.IsCastling()); + ASSERT_FALSE(en_passant.IsNormal()); + + const Move castling = Move::Make(4, 6); + ASSERT_TRUE(castling.IsCastling()); + ASSERT_FALSE(castling.IsEnPassant()); + ASSERT_FALSE(castling.IsPromotion()); + ASSERT_FALSE(castling.IsNormal()); } TEST(MoveEquality, SameAndDifferent) { - const Move move1(8, 16); - const Move move2(8, 16); - const Move move3(8, 24); - - ASSERT_EQ(move1, move2); - ASSERT_NE(move1, move3); + const Move move1(8, 16); + const Move move2(8, 16); + const Move move3(8, 24); + + ASSERT_EQ(move1, move2); + ASSERT_NE(move1, move3); } TEST(SpecialValues, NullAndNone) { - ASSERT_FALSE(Move::Null().IsValid()); - ASSERT_FALSE(Move::None().IsValid()); - ASSERT_NE(Move::Null(), Move::None()); + ASSERT_FALSE(Move::Null().IsValid()); + ASSERT_FALSE(Move::None().IsValid()); + ASSERT_NE(Move::Null(), Move::None()); } TEST(BitLayout, CorrectEncoding) { - const Move move(63, 0); - ASSERT_EQ(move.Raw() & 0x3F, 0); - ASSERT_EQ((move.Raw() >> 6) & 0x3F, 63); - ASSERT_EQ((move.Raw() >> 14) & 0x3, 0); -} - -TEST(HashFunction, Consistency) { - const Move move1(8, 16); - const Move move2(8, 16); - const Move move3(8, 24); - - Move::Hash hasher; - ASSERT_EQ(hasher(move1), hasher(move2)); - ASSERT_NE(hasher(move1), hasher(move3)); + const Move move(63, 0); + ASSERT_EQ(move.Raw() & 0x3F, 0); + ASSERT_EQ((move.Raw() >> 6) & 0x3F, 63); + ASSERT_EQ((move.Raw() >> 14) & 0x3, 0); } TEST(SizeAndAlignment, OptimalLayout) { - ASSERT_EQ(sizeof(Move), 2); - ASSERT_EQ(alignof(Move), 2); - ASSERT_TRUE(std::is_trivially_copyable_v); - ASSERT_TRUE(std::is_standard_layout_v); + ASSERT_EQ(sizeof(Move), 2); + ASSERT_EQ(alignof(Move), 2); + ASSERT_TRUE(std::is_trivially_copyable_v); + ASSERT_TRUE(std::is_standard_layout_v); } TEST(BooleanConversion, ValidityCheck) { - const Move valid_move(8, 16); - const Move none_move = Move::None(); - - ASSERT_TRUE(static_cast(valid_move)); - ASSERT_FALSE(static_cast(none_move)); + const Move valid_move(8, 16); + const Move none_move = Move::None(); + + ASSERT_TRUE(static_cast(valid_move)); + ASSERT_FALSE(static_cast(none_move)); } TEST(RawDataAccess, RoundTrip) { - const Move move(8, 16); - const std::uint16_t raw = move.Raw(); - const Move reconstructed(raw); - - ASSERT_EQ(move, reconstructed); - ASSERT_EQ(move.From(), reconstructed.From()); - ASSERT_EQ(move.To(), reconstructed.To()); + const Move move(8, 16); + const std::uint16_t raw = move.Raw(); + const Move reconstructed(raw); + + ASSERT_EQ(move, reconstructed); + ASSERT_EQ(move.From(), reconstructed.From()); + ASSERT_EQ(move.To(), reconstructed.To()); } TEST(AllSquares, ValidEncoding) { - for (BitIndex from = 0; from < 64; ++from) { - for (BitIndex to = 0; to < 64; ++to) { - if (from != to) { - const Move move(from, to); - ASSERT_EQ(move.From(), from); - ASSERT_EQ(move.To(), to); - ASSERT_TRUE(move.IsValid()); - } - } + for (BitIndex from = 0; from < 64; ++from) { + for (BitIndex to = 0; to < 64; ++to) { + if (from != to) { + const Move move(from, to); + ASSERT_EQ(move.From(), from); + ASSERT_EQ(move.To(), to); + ASSERT_TRUE(move.IsValid()); + } } + } } } // namespace CompactMoveTests \ No newline at end of file diff --git a/Tests/MoveGeneratorTests.cpp b/Tests/MoveGeneratorTests.cpp index b6678be..97b3aaf 100644 --- a/Tests/MoveGeneratorTests.cpp +++ b/Tests/MoveGeneratorTests.cpp @@ -33,7 +33,7 @@ struct GameInfo { static auto move_generator = MoveGenerator{}; const auto moves = - move_generator.GenerateMoves(position); + move_generator.GenerateMoves(position); if (depth == 1) { for (const auto &move : moves) { @@ -67,6 +67,10 @@ struct GenTestCase { std::vector infos; }; +std::ostream &operator<<(std::ostream &os, const GenTestCase &test_case) { + return os << "FEN: " << test_case.fen; +} + class GenerateMovesTest : public testing::TestWithParam { protected: void SetUp() override { @@ -104,10 +108,15 @@ TEST_P(GenerateMovesTest, Perft) { EXPECT_EQ(position, GetPosition()); EXPECT_EQ(position.GetHash(), GetPosition().GetHash()); - if (possible_games_answer) + if (possible_games_answer) { EXPECT_EQ(*possible_games, *possible_games_answer); - if (en_croissants_answer) EXPECT_EQ(*en_croissants, *en_croissants_answer); - if (castlings_answer) EXPECT_EQ(*castlings, *castlings_answer); + } + if (en_croissants_answer) { + EXPECT_EQ(*en_croissants, *en_croissants_answer); + } + if (castlings_answer) { + EXPECT_EQ(*castlings, *castlings_answer); + } } } diff --git a/Tests/MoveTests.cpp b/Tests/MoveTests.cpp new file mode 100644 index 0000000..22be6ed --- /dev/null +++ b/Tests/MoveTests.cpp @@ -0,0 +1,382 @@ +#include +#include "../Chess/Move.h" +#include +#include + +using namespace SimpleChessEngine; + +// Test basic Move functionality +TEST(MoveTest, BasicConstruction) { + Move move1(0, 1); + EXPECT_EQ(move1.From(), 0); + EXPECT_EQ(move1.To(), 1); + EXPECT_TRUE(move1.IsValid()); + + Move move2 = Move::None(); + EXPECT_FALSE(move2.IsValid()); + + Move move3 = Move::Null(); + EXPECT_FALSE(move3.IsValid()); +} + +TEST(MoveTest, MoveTypes) { + auto normal = Move(0, 1); + EXPECT_TRUE(normal.IsNormal()); + EXPECT_FALSE(normal.IsPromotion()); + EXPECT_FALSE(normal.IsEnPassant()); + EXPECT_FALSE(normal.IsCastling()); + + auto promotion = Move::Make(0, 8, Piece::kQueen); + EXPECT_TRUE(promotion.IsPromotion()); + EXPECT_EQ(promotion.PromotionPiece(), Piece::kQueen); + + auto enpassant = Move::Make(0, 8); + EXPECT_TRUE(enpassant.IsEnPassant()); + + auto castling = Move::Make(4, 6); + EXPECT_TRUE(castling.IsCastling()); +} + +// Test TypedMove value types +TEST(TypedMoveTest, ValueTypeConstruction) { + Move move(0, 1); + + PseudoLegalMove pseudo = MakeTypedMove(move); + EXPECT_EQ(pseudo.From(), 0); + EXPECT_EQ(pseudo.To(), 1); + + LegalMove legal = MakeTypedMove(move); + EXPECT_EQ(legal.From(), 0); + EXPECT_EQ(legal.To(), 1); +} + +TEST(TypedMoveTest, ReferenceTypeConstruction) { + Move move(0, 1); + + PseudoLegalMoveRef pseudo_ref = MakeTypedMove(move); + EXPECT_EQ(pseudo_ref.From(), 0); + EXPECT_EQ(pseudo_ref.To(), 1); + + PseudoLegalMoveConstRef pseudo_const_ref = MakeTypedMove(move); + EXPECT_EQ(pseudo_const_ref.From(), 0); + EXPECT_EQ(pseudo_const_ref.To(), 1); + + LegalMoveRef legal_ref = MakeTypedMove(move); + EXPECT_EQ(legal_ref.From(), 0); + EXPECT_EQ(legal_ref.To(), 1); + + LegalMoveConstRef legal_const_ref = MakeTypedMove(move); + EXPECT_EQ(legal_const_ref.From(), 0); + EXPECT_EQ(legal_const_ref.To(), 1); +} + +// Test conversions between typed moves +TEST(TypedMoveTest, ValueToReferenceConversion) { + PseudoLegalMove pseudo_value = MakeTypedMove(Move(0, 1)); + + // Value to non-const reference + PseudoLegalMoveRef pseudo_ref = pseudo_value; + EXPECT_EQ(pseudo_ref.From(), 0); + EXPECT_EQ(pseudo_ref.To(), 1); + + // Value to const reference + PseudoLegalMoveConstRef pseudo_const_ref = pseudo_value; + EXPECT_EQ(pseudo_const_ref.From(), 0); + EXPECT_EQ(pseudo_const_ref.To(), 1); + + LegalMove legal_value = MakeTypedMove(Move(2, 3)); + + // Value to non-const reference + LegalMoveRef legal_ref = legal_value; + EXPECT_EQ(legal_ref.From(), 2); + EXPECT_EQ(legal_ref.To(), 3); + + // Value to const reference + LegalMoveConstRef legal_const_ref = legal_value; + EXPECT_EQ(legal_const_ref.From(), 2); + EXPECT_EQ(legal_const_ref.To(), 3); +} + +TEST(TypedMoveTest, ReferenceToValueConversion) { + Move move(0, 1); + PseudoLegalMoveRef pseudo_ref = MakeTypedMove(move); + + // Reference to value + PseudoLegalMove pseudo_value = pseudo_ref; + EXPECT_EQ(pseudo_value.From(), 0); + EXPECT_EQ(pseudo_value.To(), 1); + + LegalMoveRef legal_ref = MakeTypedMove(move); + + // Reference to value + LegalMove legal_value = legal_ref; + EXPECT_EQ(legal_value.From(), 0); + EXPECT_EQ(legal_value.To(), 1); +} + +TEST(TypedMoveTest, TagConversion) { + Move move(0, 1); + LegalMove legal = MakeTypedMove(move); + + // LegalMove can convert to PseudoLegalMove + PseudoLegalMove pseudo = legal; + EXPECT_EQ(pseudo.From(), 0); + EXPECT_EQ(pseudo.To(), 1); + + // Both can convert to Move + Move move1 = legal; + Move move2 = pseudo; + EXPECT_EQ(move1, move); + EXPECT_EQ(move2, move); +} + +// Test MoveList +TEST(MoveListTest, BasicOperations) { + MoveList moves; + + EXPECT_TRUE(moves.empty()); + EXPECT_EQ(moves.size(), 0); + + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + moves.push_back(MakeTypedMove(Move(4, 5))); + + EXPECT_FALSE(moves.empty()); + EXPECT_EQ(moves.size(), 3); + + EXPECT_EQ(moves[0].From(), 0); + EXPECT_EQ(moves[0].To(), 1); + EXPECT_EQ(moves[1].From(), 2); + EXPECT_EQ(moves[1].To(), 3); + EXPECT_EQ(moves[2].From(), 4); + EXPECT_EQ(moves[2].To(), 5); +} + +TEST(MoveListTest, FrontBack) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + moves.push_back(MakeTypedMove(Move(4, 5))); + + EXPECT_EQ(moves.front().From(), 0); + EXPECT_EQ(moves.front().To(), 1); + EXPECT_EQ(moves.back().From(), 4); + EXPECT_EQ(moves.back().To(), 5); +} + +TEST(MoveListTest, IteratorBasics) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + moves.push_back(MakeTypedMove(Move(4, 5))); + + auto it = moves.begin(); + EXPECT_EQ(it->From(), 0); + EXPECT_EQ(it->To(), 1); + + ++it; + EXPECT_EQ(it->From(), 2); + EXPECT_EQ(it->To(), 3); + + ++it; + EXPECT_EQ(it->From(), 4); + EXPECT_EQ(it->To(), 5); + + ++it; + EXPECT_EQ(it, moves.end()); +} + +TEST(MoveListTest, ConstIterator) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + + const auto& const_moves = moves; + + auto it = const_moves.begin(); + EXPECT_EQ(it->From(), 0); + EXPECT_EQ(it->To(), 1); + + ++it; + EXPECT_EQ(it->From(), 2); + EXPECT_EQ(it->To(), 3); + + ++it; + EXPECT_EQ(it, const_moves.end()); +} + +TEST(MoveListTest, RangeBasedFor) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + moves.push_back(MakeTypedMove(Move(4, 5))); + + int count = 0; + for (const auto& move : moves) { + EXPECT_TRUE(move.IsValid()); + EXPECT_EQ(move.From(), count * 2); + EXPECT_EQ(move.To(), count * 2 + 1); + ++count; + } + EXPECT_EQ(count, 3); +} + +TEST(MoveListTest, IteratorArithmetic) { + MoveList moves; + for (int i = 0; i < 10; ++i) { + moves.push_back(MakeTypedMove(Move(i, i + 1))); + } + + auto it = moves.begin(); + + // Addition + auto it2 = it + 5; + EXPECT_EQ(it2->From(), 5); + + // Subtraction + auto it3 = it2 - 3; + EXPECT_EQ(it3->From(), 2); + + // Difference + EXPECT_EQ(it2 - it, 5); + EXPECT_EQ(it3 - it, 2); + + // Subscript + EXPECT_EQ(it[0].From(), 0); + EXPECT_EQ(it[5].From(), 5); + EXPECT_EQ(it[9].From(), 9); +} + +TEST(MoveListTest, IteratorComparison) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + moves.push_back(MakeTypedMove(Move(4, 5))); + + auto it1 = moves.begin(); + auto it2 = moves.begin(); + auto it3 = moves.begin() + 1; + auto end = moves.end(); + + EXPECT_TRUE(it1 == it2); + EXPECT_FALSE(it1 != it2); + EXPECT_TRUE(it1 != it3); + EXPECT_FALSE(it1 == it3); + + EXPECT_TRUE(it1 < it3); + EXPECT_TRUE(it1 <= it3); + EXPECT_FALSE(it1 > it3); + EXPECT_FALSE(it1 >= it3); + + EXPECT_TRUE(it3 > it1); + EXPECT_TRUE(it3 >= it1); + EXPECT_FALSE(it3 < it1); + EXPECT_FALSE(it3 <= it1); + + EXPECT_TRUE(it1 < end); + EXPECT_TRUE(it3 < end); +} + +TEST(MoveListTest, STLAlgorithms) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(5, 6))); + moves.push_back(MakeTypedMove(Move(1, 2))); + moves.push_back(MakeTypedMove(Move(3, 4))); + moves.push_back(MakeTypedMove(Move(7, 8))); + + // Find + auto it = std::find_if(moves.begin(), moves.end(), + [](const auto& m) { return m.From() == 3; }); + EXPECT_NE(it, moves.end()); + EXPECT_EQ(it->From(), 3); + EXPECT_EQ(it->To(), 4); + + // Count + int count = std::count_if(moves.begin(), moves.end(), + [](const auto& m) { return m.From() > 2; }); + EXPECT_EQ(count, 3); + + // Sort - need to work with the underlying data + auto& data = const_cast&>(moves.data()); + std::sort(data.begin(), data.end(), + [](const Move& a, const Move& b) { return a.From() < b.From(); }); + + EXPECT_EQ(moves[0].From(), 1); + EXPECT_EQ(moves[1].From(), 3); + EXPECT_EQ(moves[2].From(), 5); + EXPECT_EQ(moves[3].From(), 7); +} + +TEST(MoveListTest, EmplaceBack) { + MoveList moves; + + auto ref = moves.emplace_back(0, 1); + EXPECT_EQ(ref.From(), 0); + EXPECT_EQ(ref.To(), 1); + EXPECT_EQ(moves.size(), 1); + + moves.emplace_back(2, 3); + EXPECT_EQ(moves.size(), 2); + EXPECT_EQ(moves[1].From(), 2); + EXPECT_EQ(moves[1].To(), 3); +} + +TEST(MoveListTest, Clear) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + + EXPECT_EQ(moves.size(), 2); + + moves.clear(); + EXPECT_EQ(moves.size(), 0); + EXPECT_TRUE(moves.empty()); +} + +TEST(MoveListTest, Reserve) { + MoveList moves; + + moves.reserve(100); + EXPECT_GE(moves.capacity(), 100); + EXPECT_EQ(moves.size(), 0); + + for (int i = 0; i < 50; ++i) { + moves.push_back(MakeTypedMove(Move(i, i + 1))); + } + + EXPECT_EQ(moves.size(), 50); + EXPECT_GE(moves.capacity(), 100); +} + +// Test iterator value_type +TEST(MoveListTest, IteratorValueType) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + + using IteratorType = decltype(moves.begin()); + using ValueType = typename IteratorType::value_type; + + // value_type should be Move, not TypedMove + static_assert(std::is_same_v, + "Iterator value_type should be Move"); +} + +// Test that reference types work correctly +TEST(MoveListTest, ReferenceTypeCorrectness) { + MoveList moves; + moves.push_back(MakeTypedMove(Move(0, 1))); + moves.push_back(MakeTypedMove(Move(2, 3))); + + // Non-const iterator should return PseudoLegalMoveRef + auto it = moves.begin(); + using RefType = decltype(*it); + static_assert(std::is_same_v, + "Non-const iterator should return PseudoLegalMoveRef"); + + // Const iterator should return PseudoLegalMoveConstRef + const auto& const_moves = moves; + auto const_it = const_moves.begin(); + using ConstRefType = decltype(*const_it); + static_assert(std::is_same_v, + "Const iterator should return PseudoLegalMoveConstRef"); +} \ No newline at end of file diff --git a/Tests/PositionTests.cpp b/Tests/PositionTests.cpp index c0089bc..73fcb9b 100644 --- a/Tests/PositionTests.cpp +++ b/Tests/PositionTests.cpp @@ -13,7 +13,7 @@ TEST(DoMove, DoAndUndoEqualZero) { Position pos = start_pos; for (const auto moves = - MoveGenerator{}.GenerateMoves(pos); + MoveGenerator{}.GenerateMoves(pos); const auto &move : moves) { const auto irreversible_data = pos.GetIrreversibleData(); pos.DoMove(move); @@ -29,17 +29,17 @@ TEST(SomeMoves, DifferentHash) { std::vector moves = {"b1c3", "d7d5", "a1b1"}; std::set hashes = {pos.GetHash()}; for (const auto &move_str : moves) { - auto move = MoveFactory{}(pos, move_str); - pos.DoMove(move); + auto move = MoveCast(MoveFactory{}(pos, move_str), pos); + ASSERT_TRUE(move); + pos.DoMove(*move); ASSERT_TRUE(hashes.insert(pos.GetHash()).second); } } Position DoMoves(Position pos, const std::vector &moves) { for (const auto &move_str : moves) { - auto move = MoveFactory{}(pos, move_str); - - pos.DoMove(move); + auto move = MoveCast(MoveFactory{}(pos, move_str), pos); + pos.DoMove(*move); } return pos; }