Skip to content

Commit 062965d

Browse files
committed
feat: better net and more optimizations
1 parent 5827a37 commit 062965d

File tree

4 files changed

+93
-83
lines changed

4 files changed

+93
-83
lines changed

gem.bin

48.3 KB
Binary file not shown.

src/eval.cpp

Lines changed: 88 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -27,102 +27,113 @@ std::pair<bool, Value> checkGameStatus(Position& board) {
2727

2828
namespace {
2929

30-
// Diff of the evaluation network class. A line that follows an "N-" gutter marker belongs to
// the removed EvalNet; a line that follows an "N+" marker belongs to its replacement
// EvaluatorNet; a line that follows a doubled number is unchanged context.
//
// Removed EvalNet: kept a pointer to the weights plus int16 copies (fc1_w, fc2_w) built in
// init() by subtracting 128 from each stored weight, then evaluated 768 -> 16 -> 16 -> 1 with
// the first hidden layer clamped to [0, 127], the second clamped to [0, 16256] before a
// right shift by 7, and the final sum shifted right by 7.
//
// Added EvaluatorNet: binds a reference to *nnue::weight and, for two 768-entry inputs,
// builds one 32-wide accumulator per input from the same fc1 bias and weights, clamps each
// to [0, 32767], derives a squared copy shifted right by 15, and combines the four 32-wide
// groups with fc2_weight[0..127] into a single integer score.
class EvalNet {
30+
class EvaluatorNet {
3131
private:
32-
const nnue::Weight* w_;
33-
int16_t fc1_w[16 * 768];
34-
int16_t fc2_w[16 * 16];
32+
const nnue::Weight& w_;
3533

3634
public:
37-
EvalNet(const nnue::Weight* weight) : w_(weight) { init(); }
35+
EvaluatorNet() : w_(*nnue::weight) {}
3836

39-
void init() {
40-
for (size_t i = 0; i < 16 * 768; ++i) {
41-
fc1_w[i] = w_->fc1_weight[i] - 128;
// New call operator: x1 and x2 are each read at indices 0..767. __restrict__ is a compiler
// extension (not standard C++) that promises the two inputs do not alias.
37+
int operator()(const int8_t* __restrict__ x1, const int8_t* __restrict__ x2) const {
38+
// Set first layer bias
// The {0} initialisers here are redundant: every element is overwritten by the loop below.
39+
int32_t accumulator1[32] = {0};
40+
for (int i = 0; i < 32; ++i) {
41+
accumulator1[i] = w_.fc1_bias[i]; // auto-vectorizable
4242
}
43-
for (size_t i = 0; i < 16 * 16; ++i) {
44-
fc2_w[i] = w_->fc2_weight[i] - 128;
// Both accumulators start from the same fc1_bias; the first layer is shared by the two inputs.
43+
int32_t accumulator2[32] = {0};
44+
for (int i = 0; i < 32; ++i) {
45+
accumulator2[i] = w_.fc1_bias[i];
4546
}
46-
}
47-
48-
int operator()(const InputVector& x) const {
49-
// --- Layer 1 ---
50-
int8_t a1[16];
51-
for (int i = 0; i < 16; ++i) {
52-
int acc = ((int) (w_->fc1_bias[i]) - 128);
53-
for (int j = 0; j < 768; ++j) {
54-
acc += fc1_w[i * 768 + j] * x[j]; // x[j]=0 or 1
// fc1_weight is indexed input-major (i * 32 + j), so the inner loop walks 32 contiguous
// weights for each input entry.
47+
// Accumulate with weight
48+
for (int i = 0; i < 768; ++i) {
49+
for (int j = 0; j < 32; ++j) {
50+
accumulator1[j] += w_.fc1_weight[i * 32 + j] * x1[i];
5551
}
56-
acc = std::max(0, std::min(acc, 127));
57-
a1[i] = acc;
5852
}
59-
60-
// --- Layer 2 ---
61-
int8_t a2[16];
62-
for (int i = 0; i < 16; ++i) {
63-
int32_t acc = (int32_t(w_->fc2_bias[i]) - 128) << 7;
64-
for (int j = 0; j < 16; ++j) {
65-
acc += fc2_w[i * 16 + j] * (int) (a1[j]);
53+
for (int i = 0; i < 768; ++i) {
54+
for (int j = 0; j < 32; ++j) {
55+
accumulator2[j] += w_.fc1_weight[i * 32 + j] * x2[i];
6656
}
67-
acc = std::max(0, std::min(acc, 16256));
68-
a2[i] = (acc >> 7);
57+
}
58+
// Clamp to 0 and 32767
59+
for (int i = 0; i < 32; ++i) {
60+
accumulator1[i] = std::clamp(accumulator1[i], 0, 32767);
61+
}
62+
for (int i = 0; i < 32; ++i) {
63+
accumulator2[i] = std::clamp(accumulator2[i], 0, 32767);
64+
}
// After the clamp each value is at most 32767, so the square below fits in int32_t
// (32767 * 32767 < 2^31) and the shifted result is again below 32768.
65+
// Also compute features with square
66+
int32_t acc1_sqr[32] = {0};
67+
for (int i = 0; i < 32; ++i) {
68+
acc1_sqr[i] = accumulator1[i] * accumulator1[i];
69+
}
70+
for (int i = 0; i < 32; ++i) {
71+
acc1_sqr[i] >>= 15;
72+
}
73+
int32_t acc2_sqr[32] = {0};
74+
for (int i = 0; i < 32; ++i) {
75+
acc2_sqr[i] = accumulator2[i] * accumulator2[i];
76+
}
77+
for (int i = 0; i < 32; ++i) {
78+
acc2_sqr[i] >>= 15;
6979
}
7080

71-
// --- Layer 3 (final) ---
72-
int32_t acc = (int32_t(w_->fc3_bias[0]) - 128) << 7;
73-
for (int j = 0; j < 16; ++j) {
74-
int16_t wq = int16_t(w_->fc3_weight[j]) - 128;
75-
acc += wq * (int) (a2[j]);
// Output layer: fc2_weight is consumed in four groups of 32 (offsets 0, 32, 64, 96) paired
// with accumulator1, acc1_sqr, accumulator2 and acc2_sqr; each group sum is divided by 127
// before being added to the bias.
// NOTE(review): temp is int32_t while a single term can be as large as
// 32767 * |fc2_weight[k]|; whether 32 such terms always fit depends on the weight magnitudes
// stored in gem.bin, which are not visible here - consider a 64-bit accumulator. TODO confirm.
81+
// Values are ready to pass through dense layer.
82+
int32_t acc = w_.fc2_bias; // set bias
83+
int32_t temp = 0;
84+
for (int i = 0; i < 32; ++i) {
85+
temp += accumulator1[i] * w_.fc2_weight[i];
7686
}
87+
acc += temp / 127;
88+
temp = 0;
89+
for (int i = 0; i < 32; ++i) {
90+
temp += acc1_sqr[i] * w_.fc2_weight[i + 32];
91+
}
92+
acc += temp / 127;
93+
temp = 0;
94+
for (int i = 0; i < 32; ++i) {
95+
temp += accumulator2[i] * w_.fc2_weight[i + 64];
96+
}
97+
acc += temp / 127;
98+
temp = 0;
99+
for (int i = 0; i < 32; ++i) {
100+
temp += acc2_sqr[i] * w_.fc2_weight[i + 96];
101+
}
102+
acc += temp / 127;
77103

// Integer division truncates toward zero, which differs from the removed `>> 7` for negative
// sums. NOTE(review): 127 and 152 are presumably the scales used when the net was
// quantised/trained - confirm against the exporter.
78-
return acc >> 7;
104+
return acc / 152;
79105
}
80106
};
81107

82-
// Diff of the board-to-input encoder. The removed getInputVectorFor returned a single
// 768-entry vector; the added getInputRepresentationFor writes two 768-byte vectors through
// the output pointers v1 and v2.
//
// pos: board whose pieces are encoded; only sideToMove() and pieces(type, color) are used.
// v1, v2: output buffers. Exactly 768 bytes are copied into each, so both must have room for
//         at least 768 bytes. v1 receives the plane in which the side to move's pieces occupy
//         the first six piece blocks; v2 receives the other plane.
InputVector getInputVectorFor(const Board& pos) {
83-
InputVector vec;
84-
vec.fill(0);
108+
void getInputRepresentationFor(const Board& pos, int8_t* v1, int8_t* v2) {
109+
int8_t white[768] = {0};
110+
int8_t black[768] = {0};
85111

86112
const bool stm_white = (pos.sideToMove() == WHITE);
87113

88-
auto scan = [&](Bitboard bb, bool flip, int idx) {
114+
auto scan = [&](Bitboard bb, bool isWhite, int idx) {
89115
while (bb) {
90-
int sq = bb.pop();
91-
if (flip)
92-
sq ^= 56;
93-
vec[sq * 12 + idx * 2 + (flip ? 1 : 0)] = 1;
// Plane layout: index = (colorBlock * 6 + idx) * 64 + sq. In `white`, white pieces use
// colorBlock 0 and black pieces colorBlock 1; in `black` the assignment is swapped.
// NOTE(review): the removed code mirrored the square with `sq ^= 56` when flipping; the new
// code uses the same sq in both planes - confirm this matches how gem.bin was trained.
116+
int sq = bb.pop();
117+
int whiteIndex = ((int) (!isWhite) * 6 + idx) * 64 + sq;
118+
white[whiteIndex] = 1;
119+
int blackIndex = ((int) (isWhite) * 6 + idx) * 64 + sq;
120+
black[blackIndex] = 1;
94121
}
95122
};
96123

97-
Bitboard wP = pos.pieces(TYPE_PAWN, WHITE);
98-
Bitboard wN = pos.pieces(TYPE_KNIGHT, WHITE);
99-
Bitboard wB = pos.pieces(TYPE_BISHOP, WHITE);
100-
Bitboard wR = pos.pieces(TYPE_ROOK, WHITE);
101-
Bitboard wQ = pos.pieces(TYPE_QUEEN, WHITE);
102-
Bitboard wK = pos.pieces(TYPE_KING, WHITE);
103-
104-
Bitboard bP = pos.pieces(TYPE_PAWN, BLACK);
105-
Bitboard bN = pos.pieces(TYPE_KNIGHT, BLACK);
106-
Bitboard bB = pos.pieces(TYPE_BISHOP, BLACK);
107-
Bitboard bR = pos.pieces(TYPE_ROOK, BLACK);
108-
Bitboard bQ = pos.pieces(TYPE_QUEEN, BLACK);
109-
Bitboard bK = pos.pieces(TYPE_KING, BLACK);
110-
111-
scan(wP, !stm_white, 0);
112-
scan(wN, !stm_white, 1);
113-
scan(wB, !stm_white, 2);
114-
scan(wR, !stm_white, 3);
115-
scan(wQ, !stm_white, 4);
116-
scan(wK, !stm_white, 5);
117-
118-
scan(bP, stm_white, 0);
119-
scan(bN, stm_white, 1);
120-
scan(bB, stm_white, 2);
121-
scan(bR, stm_white, 3);
122-
scan(bQ, stm_white, 4);
123-
scan(bK, stm_white, 5);
// The twelve hand-written scans are replaced by one loop over six piece-type indices.
// NOTE(review): assumes PieceType values 0..5 enumerate the six piece types in the same order
// as the removed pawn, knight, bishop, rook, queen, king sequence - TODO confirm against the
// PieceType definition.
124+
for (int p_index = 0; p_index < 6; ++p_index) {
125+
PieceType pt = (PieceType::underlying) p_index;
126+
scan(pos.pieces(pt, WHITE), true, p_index);
127+
scan(pos.pieces(pt, BLACK), false, p_index);
128+
}
124129

125-
return vec;
// Hand the planes out so that v1 is always the side to move's plane.
130+
if (stm_white) {
131+
memcpy(v1, white, 768);
132+
memcpy(v2, black, 768);
133+
} else {
134+
memcpy(v2, white, 768);
135+
memcpy(v1, black, 768);
136+
}
126137
}
127138

128139
} // namespace
@@ -131,6 +142,9 @@ InputVector getInputVectorFor(const Board& pos) {
131142
* Main evaluation function.
132143
*/
133144
// Diff of the public entry point. After the change it builds the two 768-byte input vectors
// for `pos` on the stack, runs the function-local static EvaluatorNet on them, and returns
// the network's integer result as a Value.
Value evaluate(Position& pos) {
134-
static EvalNet net(nnue::weight);
135-
return net(getInputVectorFor(pos));
// The net is constructed once, on the first call, and reused afterwards.
145+
static EvaluatorNet net;
// vec1 and vec2 are left uninitialised here; getInputRepresentationFor copies 768 bytes into
// each before they are read.
146+
int8_t vec1[768];
147+
int8_t vec2[768];
148+
getInputRepresentationFor(pos, vec1, vec2);
149+
return net(vec1, vec2);
136150
}

src/weight.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44
#include <array>
55
#include <cstdint>
66

7-
// Embeds the network file into the binary under the name netWeight; the bytes are referenced
// further down in this header as gnetWeightData. The commit switches the embedded file from
// weight.bin to gem.bin and removes the InputVector alias.
INCBIN(netWeight, "weight.bin");
8-
9-
using InputVector = std::array<int8_t, 768>;
7+
INCBIN(netWeight, "gem.bin");
108

119
namespace nnue {
1210

1311
// In-memory layout of the embedded network file. The raw embedded bytes are reinterpreted as
// this struct, so field order and widths have to match the file exactly.
// Removed layout: three uint8 layers (768 -> 16 -> 16 -> 1).
// Added layout: int16 fields totalling (768 * 32 + 32 + 128 + 1) * 2 = 49,474 bytes, which is
// consistent with the 48.3 KB gem.bin listed in this commit.
// NOTE(review): the int16 values are read in host byte order - presumably little-endian;
// confirm against whatever writes gem.bin.
struct Weight {
14-
uint8_t fc1_weight[16 * 768];
15-
uint8_t fc1_bias[16];
16-
uint8_t fc2_weight[16 * 16];
17-
uint8_t fc2_bias[16];
18-
uint8_t fc3_weight[16];
19-
uint8_t fc3_bias[1];
// First layer: 768 inputs x 32 outputs, stored input-major (read as [i * 32 + j] in eval.cpp).
12+
int16_t fc1_weight[768 * 32];
13+
int16_t fc1_bias[32];
// Output layer: 128 weights read in four groups of 32, plus a single bias.
14+
int16_t fc2_weight[128];
15+
int16_t fc2_bias;
2016
};
2117

2218
// Views the embedded bytes as a Weight without copying.
// NOTE(review): this defines a non-const pointer object in a header, so including weight.h
// from more than one translation unit would produce duplicate definitions - consider
// `inline const Weight* const weight`. It also relies on gnetWeightData being suitably
// aligned for int16_t - TODO confirm what INCBIN guarantees.
const Weight* weight = reinterpret_cast<const Weight*>(gnetWeightData);

weight.bin

-12.3 KB
Binary file not shown.

0 commit comments

Comments
 (0)