33#include < array>
44#include < cmath>
55
6+ #define INPUT_SIZE 768
7+ #define FEATURE_SIZE 32
8+ #define LAYER1_SIZE 128
9+
610std::pair<bool , Value> checkGameStatus (Position& board) {
711 // Generate legal moves to validate checkmate or stalemate
812 Movelist moves = board.legalMoves ();
@@ -27,124 +31,156 @@ std::pair<bool, Value> checkGameStatus(Position& board) {
2731
2832namespace {
2933
30- class EvaluatorNet {
34+ struct alignas (32 ) Accumulator {
35+ int white[FEATURE_SIZE];
36+ int black[FEATURE_SIZE];
37+ };
38+
39+ std::pair<size_t , size_t > getFeatureIndices (const Color color, const PieceType pt, Square sq) {
40+ const bool isWhite = color == WHITE;
41+ const size_t whiteIndex = ((int ) (!isWhite) * 6 + (int ) pt) * 64 + sq.index ();
42+ const size_t blackIndex = ((int ) (isWhite) * 6 + (int ) pt) * 64 + sq.flip ().index ();
43+ return {whiteIndex, blackIndex};
44+ }
45+
46+ class NNUEState {
3147private:
32- const nnue::Weight& w_;
48+ std::vector<Accumulator> accumulatorStack;
49+ const nnue::Weight& w;
3350
3451public:
35- EvaluatorNet () : w_(*nnue::weight) { }
52+ inline Accumulator& curr () { return accumulatorStack. back (); }
3653
37- int operator ()(const int8_t * __restrict__ x1, const int8_t * __restrict__ x2) const {
38- // Set first layer bias
39- int32_t accumulator1[32 ] = {0 };
40- for (int i = 0 ; i < 32 ; ++i) {
41- accumulator1[i] = w_.fc1_bias [i]; // auto-vectorizable
42- }
43- int32_t accumulator2[32 ] = {0 };
44- for (int i = 0 ; i < 32 ; ++i) {
45- accumulator2[i] = w_.fc1_bias [i];
46- }
47- // Accumulate with weight
48- for (int i = 0 ; i < 768 ; ++i) {
49- for (int j = 0 ; j < 32 ; ++j) {
50- accumulator1[j] += w_.fc1_weight [i * 32 + j] * x1[i];
51- }
52- }
53- for (int i = 0 ; i < 768 ; ++i) {
54- for (int j = 0 ; j < 32 ; ++j) {
55- accumulator2[j] += w_.fc1_weight [i * 32 + j] * x2[i];
56- }
57- }
58- // Clamp to 0 and 32767
59- for (int i = 0 ; i < 32 ; ++i) {
60- accumulator1[i] = std::clamp (accumulator1[i], 0 , 32767 );
54+ public:
55+ NNUEState () : w(*nnue::weight) {}
56+
57+ void push () {
58+ Accumulator copy = curr ();
59+ accumulatorStack.push_back (copy);
60+ }
61+ void pop () { accumulatorStack.pop_back (); }
62+
63+ void reset (const Position& pos) {
64+ // Create a new accumulator
65+ Accumulator accum;
66+ // Initialize with bias
67+ for (int i = 0 ; i < FEATURE_SIZE; ++i) {
68+ accum.white [i] = w.fc1_bias [i];
69+ accum.black [i] = w.fc1_bias [i];
70+ }
71+ // Clear the stack and push the accumulator
72+ accumulatorStack.clear ();
73+ accumulatorStack.push_back (std::move (accum));
74+ // Call the update functions
75+ Bitboard occ = pos.occ ();
76+ while (occ) {
77+ Square sq = occ.pop ();
78+ Piece p = pos.at (sq);
79+ update<true >(p, sq);
6180 }
62- for (int i = 0 ; i < 32 ; ++i) {
63- accumulator2[i] = std::clamp (accumulator2[i], 0 , 32767 );
81+ }
82+
83+ template <bool activate> void update (const Piece piece, const Square square) {
84+ update<activate>(piece.color (), piece.type (), square);
85+ }
86+
87+ template <bool activate> void update (const Color color, const PieceType pt, const Square sq) {
88+ const auto [wi, bi] = getFeatureIndices (color, pt, sq);
89+ constexpr int multiplier = (activate ? 1 : -1 );
90+ for (int i = 0 ; i < FEATURE_SIZE; ++i) {
91+ curr ().white [i] += w.fc1_weight [wi * FEATURE_SIZE + i] * multiplier;
6492 }
65- // Also compute features with square
66- int32_t acc1_sqr[32 ] = {0 };
67- for (int i = 0 ; i < 32 ; ++i) {
68- acc1_sqr[i] = accumulator1[i] * accumulator1[i];
93+ for (int i = 0 ; i < FEATURE_SIZE; ++i) {
94+ curr ().black [i] += w.fc1_weight [bi * FEATURE_SIZE + i] * multiplier;
6995 }
96+ }
97+
98+ int evaluate (const Color color) {
99+ const int * input1 = ((color == WHITE) ? curr ().white : curr ().black );
100+ const int * input2 = ((color == WHITE) ? curr ().black : curr ().white );
101+ // Clipped ReLU activation
102+ int v1[FEATURE_SIZE], v2[FEATURE_SIZE];
70103 for (int i = 0 ; i < 32 ; ++i) {
71- acc1_sqr[i] >>= 15 ;
104+ v1[i] = std::clamp (input1[i], 0 , 32767 );
105+ v2[i] = std::clamp (input2[i], 0 , 32767 );
72106 }
73- int32_t acc2_sqr[32 ] = {0 };
107+ // Clipped square activation
108+ int v1s[FEATURE_SIZE], v2s[FEATURE_SIZE];
74109 for (int i = 0 ; i < 32 ; ++i) {
75- acc2_sqr[i] = accumulator2[i] * accumulator2[i];
110+ v1s[i] = v1[i] * v1[i];
111+ v2s[i] = v2[i] * v2[i];
76112 }
77113 for (int i = 0 ; i < 32 ; ++i) {
78- acc2_sqr[i] >>= 15 ;
114+ v1s[i] >>= 15 ;
115+ v2s[i] >>= 15 ;
79116 }
80-
81- // Values are ready to pass through dense layer.
82- int32_t acc = w_.fc2_bias ; // set bias
83- int32_t temp = 0 ;
117+ // Pass through second layer
118+ int temp[4 ] = {0 };
84119 for (int i = 0 ; i < 32 ; ++i) {
85- temp += accumulator1 [i] * w_ .fc2_weight [i];
120+ temp[ 0 ] += v1 [i] * w .fc2_weight [i];
86121 }
87- acc += temp / 127 ;
88- temp = 0 ;
89122 for (int i = 0 ; i < 32 ; ++i) {
90- temp += acc1_sqr [i] * w_ .fc2_weight [i + 32 ];
123+ temp[ 1 ] += v1s [i] * w .fc2_weight [i + FEATURE_SIZE ];
91124 }
92- acc += temp / 127 ;
93- temp = 0 ;
94125 for (int i = 0 ; i < 32 ; ++i) {
95- temp += accumulator2 [i] * w_ .fc2_weight [i + 64 ];
126+ temp[ 2 ] += v2 [i] * w .fc2_weight [i + FEATURE_SIZE * 2 ];
96127 }
97- acc += temp / 127 ;
98- temp = 0 ;
99128 for (int i = 0 ; i < 32 ; ++i) {
100- temp += acc2_sqr [i] * w_ .fc2_weight [i + 96 ];
129+ temp[ 3 ] += v2s [i] * w .fc2_weight [i + FEATURE_SIZE * 3 ];
101130 }
102- acc += temp / 127 ;
103-
104- return acc / 152 ;
131+ // Accumulate
132+ int y = w.fc2_bias + temp[0 ] / 127 + temp[1 ] / 127 + temp[2 ] / 127 + temp[3 ] / 127 ;
133+ y = y / 152 ;
134+ return y;
105135 }
106136};
107137
108- void getInputRepresentationFor (const Board& pos, int8_t * v1, int8_t * v2) {
109- int8_t white[768 ] = {0 };
110- int8_t black[768 ] = {0 };
138+ // global instance
139+ NNUEState gNNUE ;
111140
112- const bool stm_white = (pos.sideToMove () == WHITE);
113-
114- auto scan = [&](Bitboard bb, bool isWhite, int idx) {
115- while (bb) {
116- int sq = bb.pop ();
117- int whiteIndex = ((int ) (!isWhite) * 6 + idx) * 64 + sq;
118- white[whiteIndex] = 1 ;
119- int blackIndex = ((int ) (isWhite) * 6 + idx) * 64 + sq;
120- black[blackIndex] = 1 ;
121- }
122- };
141+ } // namespace
123142
124- for (int p_index = 0 ; p_index < 6 ; ++p_index) {
125- PieceType pt = (PieceType::underlying) p_index;
126- scan (pos.pieces (pt, WHITE), true , p_index);
127- scan (pos.pieces (pt, BLACK), false , p_index);
128- }
143+ /* *
144+ * Main evaluation function.
145+ */
146+ Value evaluate (Position& pos) {
147+ gNNUE .reset (pos);
148+ return gNNUE .evaluate (pos.sideToMove ());
149+ }
129150
130- if (stm_white) {
131- memcpy (v1, white, 768 );
132- memcpy (v2, black, 768 );
133- } else {
134- memcpy (v2, white, 768 );
135- memcpy (v1, black, 768 );
136- }
151+ /* *
152+ * Update evaluator state. This tells the net to incrementally
153+ * update since you make some move.
154+ */
155+ void updateEvaluatorState (const Position& pos, const Move& move) {
156+ // const Piece pFrom = pos.at(move.from());
157+ // const Square sFrom = move.from();
158+ // const Piece pTo = pos.at(move.to());
159+ // const Square sTo = move.to();
137160}
138161
139- } // namespace
162+ /* *
163+ * This tells the net to refresh all accumulators.
164+ */
165+ void updateEvaluatorState (const Position& pos) {
166+ // gNNUE.reset(pos);
167+ }
140168
141169/* *
142- * Main evaluation function .
170+ * This tells the net that you have undone a move .
143171 */
144- Value evaluate (Position& pos) {
145- static EvaluatorNet net;
146- int8_t vec1[768 ];
147- int8_t vec2[768 ];
148- getInputRepresentationFor (pos, vec1, vec2);
149- return net (vec1, vec2);
172+ void updateEvaluatorState () {
173+ // gNNUE.pop();
150174}
175+
176+ /* *
177+ * go depth 8
178+ info string tc 0 0
179+ info depth 1 score cp 32 nodes 21 seldepth 1 time 7 pv d2d4
180+ info depth 2 score cp 39 nodes 78 seldepth 2 time 10 pv d2d4
181+ info depth 3 score cp 21 nodes 693 seldepth 8 time 15 pv e2e4
182+ info depth 4 score cp 35 nodes 2027 seldepth 8 time 20 pv d2d4
183+ info depth 5 score cp 37 nodes 9006 seldepth 10 time 46 pv d2d4
184+ info depth 6 score cp 40 nodes 22119 seldepth 15 time 125 pv d2d4
185+ info depth 7 score cp 29 nodes 68574 seldepth 15 time 262 pv d2d4
186+ */
0 commit comments