Skip to content

Commit 7ef7295

Browse files
committed
feat: cleaner code
1 parent 062965d commit 7ef7295

File tree

5 files changed

+148
-95
lines changed

5 files changed

+148
-95
lines changed

gem.bin

0 Bytes
Binary file not shown.

src/eval.cpp

Lines changed: 124 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include <array>
44
#include <cmath>
55

6+
#define INPUT_SIZE 768
7+
#define FEATURE_SIZE 32
8+
#define LAYER1_SIZE 128
9+
610
std::pair<bool, Value> checkGameStatus(Position& board) {
711
// Generate legal moves to validate checkmate or stalemate
812
Movelist moves = board.legalMoves();
@@ -27,124 +31,156 @@ std::pair<bool, Value> checkGameStatus(Position& board) {
2731

2832
namespace {
2933

30-
class EvaluatorNet {
34+
// Pair of first-layer accumulators, each holding FEATURE_SIZE ints.
// The struct is over-aligned to 32 bytes (presumably for vectorised access — TODO confirm).
struct alignas(32) Accumulator {
35+
// Sums built from the "white" feature indices returned by getFeatureIndices.
int white[FEATURE_SIZE];
36+
// Sums built from the "black" feature indices returned by getFeatureIndices.
int black[FEATURE_SIZE];
37+
};
38+
39+
/**
 * Computes the two input-feature indices for one piece on one square.
 *
 * Both indices have the form (side * 6 + pt) * 64 + square, with side being 0 or 1.
 *
 * @param color colour of the piece
 * @param pt    piece type, cast to int and used as an offset
 *              (assumed to lie in 0..5 so the result stays below INPUT_SIZE — TODO confirm)
 * @param sq    square the piece stands on
 * @return {whiteIndex, blackIndex}
 */
std::pair<size_t, size_t> getFeatureIndices(const Color color, const PieceType pt, Square sq) {
40+
const bool isWhite = color == WHITE;
41+
// White index: side is 0 for a white piece and 1 otherwise; uses the square as given.
const size_t whiteIndex = ((int) (!isWhite) * 6 + (int) pt) * 64 + sq.index();
42+
// Black index: side is 1 for a white piece and 0 otherwise; uses sq.flip()
// (presumably the board mirrored for black's point of view — TODO confirm).
const size_t blackIndex = ((int) (isWhite) * 6 + (int) pt) * 64 + sq.flip().index();
43+
return {whiteIndex, blackIndex};
44+
}
45+
46+
class NNUEState {
3147
private:
32-
const nnue::Weight& w_;
48+
std::vector<Accumulator> accumulatorStack;
49+
const nnue::Weight& w;
3350

3451
public:
35-
EvaluatorNet() : w_(*nnue::weight) {}
52+
// Returns a mutable reference to the accumulator on top of the stack.
// NOTE(review): calls back() with no emptiness check, so the stack must be non-empty.
inline Accumulator& curr() { return accumulatorStack.back(); }
3653

37-
int operator()(const int8_t* __restrict__ x1, const int8_t* __restrict__ x2) const {
38-
// Set first layer bias
39-
int32_t accumulator1[32] = {0};
40-
for (int i = 0; i < 32; ++i) {
41-
accumulator1[i] = w_.fc1_bias[i]; // auto-vectorizable
42-
}
43-
int32_t accumulator2[32] = {0};
44-
for (int i = 0; i < 32; ++i) {
45-
accumulator2[i] = w_.fc1_bias[i];
46-
}
47-
// Accumulate with weight
48-
for (int i = 0; i < 768; ++i) {
49-
for (int j = 0; j < 32; ++j) {
50-
accumulator1[j] += w_.fc1_weight[i * 32 + j] * x1[i];
51-
}
52-
}
53-
for (int i = 0; i < 768; ++i) {
54-
for (int j = 0; j < 32; ++j) {
55-
accumulator2[j] += w_.fc1_weight[i * 32 + j] * x2[i];
56-
}
57-
}
58-
// Clamp to 0 and 32767
59-
for (int i = 0; i < 32; ++i) {
60-
accumulator1[i] = std::clamp(accumulator1[i], 0, 32767);
54+
public:
55+
// Binds the reference member `w` to the object nnue::weight points to.
// NOTE(review): the accumulator stack is left empty here — presumably reset() must run before use; confirm.
NNUEState() : w(*nnue::weight) {}
56+
57+
// Duplicates the current top accumulator: copies it into a local and appends
// that copy, so the new top starts out equal to the previous one.
// NOTE(review): goes through curr(), so the stack must already hold an entry.
void push() {
58+
Accumulator copy = curr();
59+
accumulatorStack.push_back(copy);
60+
}
61+
// Discards the top accumulator. NOTE(review): no emptiness check is performed.
void pop() { accumulatorStack.pop_back(); }
62+
63+
void reset(const Position& pos) {
64+
// Create a new accumulator
65+
Accumulator accum;
66+
// Initialize with bias
67+
for (int i = 0; i < FEATURE_SIZE; ++i) {
68+
accum.white[i] = w.fc1_bias[i];
69+
accum.black[i] = w.fc1_bias[i];
70+
}
71+
// Clear the stack and push the accumulator
72+
accumulatorStack.clear();
73+
accumulatorStack.push_back(std::move(accum));
74+
// Call the update functions
75+
Bitboard occ = pos.occ();
76+
while (occ) {
77+
Square sq = occ.pop();
78+
Piece p = pos.at(sq);
79+
update<true>(p, sq);
6180
}
62-
for (int i = 0; i < 32; ++i) {
63-
accumulator2[i] = std::clamp(accumulator2[i], 0, 32767);
81+
}
82+
83+
// Convenience overload: unpacks the piece's colour and type and forwards to
// update<activate>(Color, PieceType, Square).
template <bool activate> void update(const Piece piece, const Square square) {
84+
update<activate>(piece.color(), piece.type(), square);
85+
}
86+
87+
template <bool activate> void update(const Color color, const PieceType pt, const Square sq) {
88+
const auto [wi, bi] = getFeatureIndices(color, pt, sq);
89+
constexpr int multiplier = (activate ? 1 : -1);
90+
for (int i = 0; i < FEATURE_SIZE; ++i) {
91+
curr().white[i] += w.fc1_weight[wi * FEATURE_SIZE + i] * multiplier;
6492
}
65-
// Also compute features with square
66-
int32_t acc1_sqr[32] = {0};
67-
for (int i = 0; i < 32; ++i) {
68-
acc1_sqr[i] = accumulator1[i] * accumulator1[i];
93+
for (int i = 0; i < FEATURE_SIZE; ++i) {
94+
curr().black[i] += w.fc1_weight[bi * FEATURE_SIZE + i] * multiplier;
6995
}
96+
}
97+
98+
int evaluate(const Color color) {
99+
const int* input1 = ((color == WHITE) ? curr().white : curr().black);
100+
const int* input2 = ((color == WHITE) ? curr().black : curr().white);
101+
// Clipped ReLU activation
102+
int v1[FEATURE_SIZE], v2[FEATURE_SIZE];
70103
for (int i = 0; i < 32; ++i) {
71-
acc1_sqr[i] >>= 15;
104+
v1[i] = std::clamp(input1[i], 0, 32767);
105+
v2[i] = std::clamp(input2[i], 0, 32767);
72106
}
73-
int32_t acc2_sqr[32] = {0};
107+
// Clipped square activation
108+
int v1s[FEATURE_SIZE], v2s[FEATURE_SIZE];
74109
for (int i = 0; i < 32; ++i) {
75-
acc2_sqr[i] = accumulator2[i] * accumulator2[i];
110+
v1s[i] = v1[i] * v1[i];
111+
v2s[i] = v2[i] * v2[i];
76112
}
77113
for (int i = 0; i < 32; ++i) {
78-
acc2_sqr[i] >>= 15;
114+
v1s[i] >>= 15;
115+
v2s[i] >>= 15;
79116
}
80-
81-
// Values are ready to pass through dense layer.
82-
int32_t acc = w_.fc2_bias; // set bias
83-
int32_t temp = 0;
117+
// Pass through second layer
118+
int temp[4] = {0};
84119
for (int i = 0; i < 32; ++i) {
85-
temp += accumulator1[i] * w_.fc2_weight[i];
120+
temp[0] += v1[i] * w.fc2_weight[i];
86121
}
87-
acc += temp / 127;
88-
temp = 0;
89122
for (int i = 0; i < 32; ++i) {
90-
temp += acc1_sqr[i] * w_.fc2_weight[i + 32];
123+
temp[1] += v1s[i] * w.fc2_weight[i + FEATURE_SIZE];
91124
}
92-
acc += temp / 127;
93-
temp = 0;
94125
for (int i = 0; i < 32; ++i) {
95-
temp += accumulator2[i] * w_.fc2_weight[i + 64];
126+
temp[2] += v2[i] * w.fc2_weight[i + FEATURE_SIZE * 2];
96127
}
97-
acc += temp / 127;
98-
temp = 0;
99128
for (int i = 0; i < 32; ++i) {
100-
temp += acc2_sqr[i] * w_.fc2_weight[i + 96];
129+
temp[3] += v2s[i] * w.fc2_weight[i + FEATURE_SIZE * 3];
101130
}
102-
acc += temp / 127;
103-
104-
return acc / 152;
131+
// Accumulate
132+
int y = w.fc2_bias + temp[0] / 127 + temp[1] / 127 + temp[2] / 127 + temp[3] / 127;
133+
y = y / 152;
134+
return y;
105135
}
106136
};
107137

108-
void getInputRepresentationFor(const Board& pos, int8_t* v1, int8_t* v2) {
109-
int8_t white[768] = {0};
110-
int8_t black[768] = {0};
138+
// global instance
139+
NNUEState gNNUE;
111140

112-
const bool stm_white = (pos.sideToMove() == WHITE);
113-
114-
auto scan = [&](Bitboard bb, bool isWhite, int idx) {
115-
while (bb) {
116-
int sq = bb.pop();
117-
int whiteIndex = ((int) (!isWhite) * 6 + idx) * 64 + sq;
118-
white[whiteIndex] = 1;
119-
int blackIndex = ((int) (isWhite) * 6 + idx) * 64 + sq;
120-
black[blackIndex] = 1;
121-
}
122-
};
141+
} // namespace
123142

124-
for (int p_index = 0; p_index < 6; ++p_index) {
125-
PieceType pt = (PieceType::underlying) p_index;
126-
scan(pos.pieces(pt, WHITE), true, p_index);
127-
scan(pos.pieces(pt, BLACK), false, p_index);
128-
}
143+
/**
144+
* Main evaluation function.
*
* Rebuilds the global NNUE accumulator from `pos` on every call (full reset,
* no incremental update), then runs the network with the side to move.
*
* @param pos position to evaluate
* @return the value produced by NNUEState::evaluate for pos.sideToMove()
145+
*/
146+
Value evaluate(Position& pos) {
147+
gNNUE.reset(pos);
148+
return gNNUE.evaluate(pos.sideToMove());
149+
}
129150

130-
if (stm_white) {
131-
memcpy(v1, white, 768);
132-
memcpy(v2, black, 768);
133-
} else {
134-
memcpy(v2, white, 768);
135-
memcpy(v1, black, 768);
136-
}
151+
/**
152+
* Update evaluator state. Tells the net that a move is being made so that
153+
* it can update its accumulators incrementally.
*
* NOTE: currently a no-op. The whole body is commented out, and evaluate()
* rebuilds the accumulator from scratch on every call instead.
*
* @param pos  position the move applies to
* @param move the move in question
154+
*/
155+
void updateEvaluatorState(const Position& pos, const Move& move) {
156+
// const Piece pFrom = pos.at(move.from());
157+
// const Square sFrom = move.from();
158+
// const Piece pTo = pos.at(move.to());
159+
// const Square sTo = move.to();
137160
}
138161

139-
} // namespace
162+
/**
163+
* This tells the net to refresh all accumulators.
*
* NOTE: currently a no-op. The reset call below is commented out; evaluate()
* performs the reset itself on every call.
*
* @param pos position the accumulators would be rebuilt from
164+
*/
165+
void updateEvaluatorState(const Position& pos) {
166+
// gNNUE.reset(pos);
167+
}
140168

141169
/**
142-
* Main evaluation function.
170+
* This tells the net that you have undone a move.
143171
*/
144-
Value evaluate(Position& pos) {
145-
static EvaluatorNet net;
146-
int8_t vec1[768];
147-
int8_t vec2[768];
148-
getInputRepresentationFor(pos, vec1, vec2);
149-
return net(vec1, vec2);
172+
// No-argument overload. Currently a no-op: the pop of the accumulator stack
// below is commented out.
void updateEvaluatorState() {
173+
// gNNUE.pop();
150174
}
175+
176+
/**
177+
* go depth 8
178+
info string tc 0 0
179+
info depth 1 score cp 32 nodes 21 seldepth 1 time 7 pv d2d4
180+
info depth 2 score cp 39 nodes 78 seldepth 2 time 10 pv d2d4
181+
info depth 3 score cp 21 nodes 693 seldepth 8 time 15 pv e2e4
182+
info depth 4 score cp 35 nodes 2027 seldepth 8 time 20 pv d2d4
183+
info depth 5 score cp 37 nodes 9006 seldepth 10 time 46 pv d2d4
184+
info depth 6 score cp 40 nodes 22119 seldepth 15 time 125 pv d2d4
185+
info depth 7 score cp 29 nodes 68574 seldepth 15 time 262 pv d2d4
186+
*/

src/eval.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
/**
77
* Checks if the game is over and returns the appropriate score.
88
*/
9-
std::pair<bool, Value> checkGameStatus(Position &board);
9+
std::pair<bool, Value> checkGameStatus(Position& board);
1010

1111
// Override `operator<<` for fast printing of Values and Scores
12-
inline std::ostream &operator<<(std::ostream &os, const Value &s) {
12+
inline std::ostream& operator<<(std::ostream& os, const Value& s) {
1313
if (!s.isValid()) {
1414
os << "(invalid score)";
1515
} else {
@@ -22,9 +22,16 @@ inline std::ostream &operator<<(std::ostream &os, const Value &s) {
2222
return os;
2323
}
2424
// This is mostly for debugging purposes
25-
inline std::ostream &operator<<(std::ostream &os, const Score &s) {
26-
os << "S(" << (int)s.mg << ", " << (int)s.eg << ")";
25+
inline std::ostream& operator<<(std::ostream& os, const Score& s) {
26+
os << "S(" << (int) s.mg << ", " << (int) s.eg << ")";
2727
return os;
2828
}
2929

30-
Value evaluate(Position &pos);
30+
Value evaluate(Position& pos);
31+
32+
/**
33+
* Update evaluator network state.
34+
*/
35+
void updateEvaluatorState(const Position& pos, const Move& move);
36+
void updateEvaluatorState(const Position& pos);
37+
void updateEvaluatorState();

src/search.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class Searcher {
113113
stack[i] = Scratchpad();
114114
}
115115
searchInterrupted = false;
116+
updateEvaluatorState(pos);
116117

117118
this->tc = timeControl;
118119

@@ -128,14 +129,15 @@ class Searcher {
128129
// We don't have to waste time searching if there is only one reply in
129130
// competition
130131
if (tc.competitionMode && mp.size() == 1) {
132+
updateEvaluatorState(pos); // refresh evaluator
131133
Value staticEval = evaluate(pos);
132134
result = SearchResult::from(stats, tc, staticEval, mp.pick(), pvTable);
133135
return;
134136
}
135137

136138
Value alpha = Value::matedIn(0);
137139
Value beta = Value::mateIn(0);
138-
Value window = 30;
140+
Value window = 18;
139141
Value bestEvalRoot = Value::none();
140142
Move bestMoveRoot = Move::NO_MOVE;
141143

@@ -210,6 +212,10 @@ class Searcher {
210212
if (hasReachedHardLimit()) {
211213
return alpha;
212214
}
215+
// At root node, refresh evaluator
216+
if (isRootNode) {
217+
updateEvaluatorState(pos);
218+
}
213219
// Go into quiescence search if no more plys are left to search
214220
if (depth <= 0) {
215221
return qsearch<NT>(alpha, beta, 10, ply);
@@ -345,6 +351,7 @@ class Searcher {
345351
continue;
346352
}
347353

354+
updateEvaluatorState(pos, m);
348355
pos.makeMove(m);
349356
stack[ply].currentMove = m.move();
350357
Value score = VALUE_NONE;
@@ -362,6 +369,7 @@ class Searcher {
362369
}
363370

364371
pos.unmakeMove(m);
372+
updateEvaluatorState();
365373
stack[ply].currentMove = 0;
366374

367375
if (score > bestValue) {
@@ -483,12 +491,14 @@ class Searcher {
483491
continue;
484492
}
485493

494+
updateEvaluatorState(pos, m);
486495
pos.makeMove(m);
487496
stack[ply].currentMove = m.move();
488497

489498
const Value v = -qsearch<NT>(-beta, -alpha, depth - 1, ply + 1);
490499

491500
pos.unmakeMove(m);
501+
updateEvaluatorState();
492502
stack[ply].currentMove = 0;
493503

494504
if (v > bestValue) {

src/weight.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ INCBIN(netWeight, "gem.bin");
88

99
namespace nnue {
1010

11-
struct Weight {
11+
struct alignas(32) Weight {
1212
int16_t fc1_weight[768 * 32];
1313
int16_t fc1_bias[32];
1414
int16_t fc2_weight[128];

0 commit comments

Comments
 (0)