@@ -27,102 +27,113 @@ std::pair<bool, Value> checkGameStatus(Position& board) {
2727
2828namespace {
2929
30- class EvalNet {
30+ class EvaluatorNet {
3131private:
32- const nnue::Weight* w_;
33- int16_t fc1_w[16 * 768 ];
34- int16_t fc2_w[16 * 16 ];
32+ const nnue::Weight& w_;
3533
3634public:
37- EvalNet ( const nnue::Weight* weight ) : w_(weight) { init (); }
35+ EvaluatorNet ( ) : w_(*nnue:: weight) {}
3836
39- void init () {
40- for (size_t i = 0 ; i < 16 * 768 ; ++i) {
41- fc1_w[i] = w_->fc1_weight [i] - 128 ;
37+ int operator ()(const int8_t * __restrict__ x1, const int8_t * __restrict__ x2) const {
38+ // Set first layer bias
39+ int32_t accumulator1[32 ] = {0 };
40+ for (int i = 0 ; i < 32 ; ++i) {
41+ accumulator1[i] = w_.fc1_bias [i]; // auto-vectorizable
4242 }
43- for (size_t i = 0 ; i < 16 * 16 ; ++i) {
44- fc2_w[i] = w_->fc2_weight [i] - 128 ;
43+ int32_t accumulator2[32 ] = {0 };
44+ for (int i = 0 ; i < 32 ; ++i) {
45+ accumulator2[i] = w_.fc1_bias [i];
4546 }
46- }
47-
48- int operator ()(const InputVector& x) const {
49- // --- Layer 1 ---
50- int8_t a1[16 ];
51- for (int i = 0 ; i < 16 ; ++i) {
52- int acc = ((int ) (w_->fc1_bias [i]) - 128 );
53- for (int j = 0 ; j < 768 ; ++j) {
54- acc += fc1_w[i * 768 + j] * x[j]; // x[j]=0 or 1
47+ // Accumulate with weight
48+ for (int i = 0 ; i < 768 ; ++i) {
49+ for (int j = 0 ; j < 32 ; ++j) {
50+ accumulator1[j] += w_.fc1_weight [i * 32 + j] * x1[i];
5551 }
56- acc = std::max (0 , std::min (acc, 127 ));
57- a1[i] = acc;
5852 }
59-
60- // --- Layer 2 ---
61- int8_t a2[16 ];
62- for (int i = 0 ; i < 16 ; ++i) {
63- int32_t acc = (int32_t (w_->fc2_bias [i]) - 128 ) << 7 ;
64- for (int j = 0 ; j < 16 ; ++j) {
65- acc += fc2_w[i * 16 + j] * (int ) (a1[j]);
53+ for (int i = 0 ; i < 768 ; ++i) {
54+ for (int j = 0 ; j < 32 ; ++j) {
55+ accumulator2[j] += w_.fc1_weight [i * 32 + j] * x2[i];
6656 }
67- acc = std::max (0 , std::min (acc, 16256 ));
68- a2[i] = (acc >> 7 );
57+ }
58+ // Clamp to 0 and 32767
59+ for (int i = 0 ; i < 32 ; ++i) {
60+ accumulator1[i] = std::clamp (accumulator1[i], 0 , 32767 );
61+ }
62+ for (int i = 0 ; i < 32 ; ++i) {
63+ accumulator2[i] = std::clamp (accumulator2[i], 0 , 32767 );
64+ }
65+ // Also compute features with square
66+ int32_t acc1_sqr[32 ] = {0 };
67+ for (int i = 0 ; i < 32 ; ++i) {
68+ acc1_sqr[i] = accumulator1[i] * accumulator1[i];
69+ }
70+ for (int i = 0 ; i < 32 ; ++i) {
71+ acc1_sqr[i] >>= 15 ;
72+ }
73+ int32_t acc2_sqr[32 ] = {0 };
74+ for (int i = 0 ; i < 32 ; ++i) {
75+ acc2_sqr[i] = accumulator2[i] * accumulator2[i];
76+ }
77+ for (int i = 0 ; i < 32 ; ++i) {
78+ acc2_sqr[i] >>= 15 ;
6979 }
7080
71- // --- Layer 3 (final) ---
72- int32_t acc = ( int32_t (w_-> fc3_bias [ 0 ]) - 128 ) << 7 ;
73- for ( int j = 0 ; j < 16 ; ++j) {
74- int16_t wq = int16_t (w_-> fc3_weight [j]) - 128 ;
75- acc += wq * ( int ) (a2[j]) ;
81+ // Values are ready to pass through dense layer.
82+ int32_t acc = w_. fc2_bias ; // set bias
83+ int32_t temp = 0 ;
84+ for ( int i = 0 ; i < 32 ; ++i) {
85+ temp += accumulator1[i] * w_. fc2_weight [i] ;
7686 }
87+ acc += temp / 127 ;
88+ temp = 0 ;
89+ for (int i = 0 ; i < 32 ; ++i) {
90+ temp += acc1_sqr[i] * w_.fc2_weight [i + 32 ];
91+ }
92+ acc += temp / 127 ;
93+ temp = 0 ;
94+ for (int i = 0 ; i < 32 ; ++i) {
95+ temp += accumulator2[i] * w_.fc2_weight [i + 64 ];
96+ }
97+ acc += temp / 127 ;
98+ temp = 0 ;
99+ for (int i = 0 ; i < 32 ; ++i) {
100+ temp += acc2_sqr[i] * w_.fc2_weight [i + 96 ];
101+ }
102+ acc += temp / 127 ;
77103
78- return acc >> 7 ;
104+ return acc / 152 ;
79105 }
80106};
81107
82- InputVector getInputVectorFor (const Board& pos) {
83- InputVector vec ;
84- vec. fill ( 0 ) ;
108+ void getInputRepresentationFor (const Board& pos, int8_t * v1, int8_t * v2 ) {
109+ int8_t white[ 768 ] = { 0 } ;
110+ int8_t black[ 768 ] = { 0 } ;
85111
86112 const bool stm_white = (pos.sideToMove () == WHITE);
87113
88- auto scan = [&](Bitboard bb, bool flip , int idx) {
114+ auto scan = [&](Bitboard bb, bool isWhite , int idx) {
89115 while (bb) {
90- int sq = bb.pop ();
91- if (flip)
92- sq ^= 56 ;
93- vec[sq * 12 + idx * 2 + (flip ? 1 : 0 )] = 1 ;
116+ int sq = bb.pop ();
117+ int whiteIndex = ((int ) (!isWhite) * 6 + idx) * 64 + sq;
118+ white[whiteIndex] = 1 ;
119+ int blackIndex = ((int ) (isWhite) * 6 + idx) * 64 + sq;
120+ black[blackIndex] = 1 ;
94121 }
95122 };
96123
97- Bitboard wP = pos.pieces (TYPE_PAWN, WHITE);
98- Bitboard wN = pos.pieces (TYPE_KNIGHT, WHITE);
99- Bitboard wB = pos.pieces (TYPE_BISHOP, WHITE);
100- Bitboard wR = pos.pieces (TYPE_ROOK, WHITE);
101- Bitboard wQ = pos.pieces (TYPE_QUEEN, WHITE);
102- Bitboard wK = pos.pieces (TYPE_KING, WHITE);
103-
104- Bitboard bP = pos.pieces (TYPE_PAWN, BLACK);
105- Bitboard bN = pos.pieces (TYPE_KNIGHT, BLACK);
106- Bitboard bB = pos.pieces (TYPE_BISHOP, BLACK);
107- Bitboard bR = pos.pieces (TYPE_ROOK, BLACK);
108- Bitboard bQ = pos.pieces (TYPE_QUEEN, BLACK);
109- Bitboard bK = pos.pieces (TYPE_KING, BLACK);
110-
111- scan (wP, !stm_white, 0 );
112- scan (wN, !stm_white, 1 );
113- scan (wB, !stm_white, 2 );
114- scan (wR, !stm_white, 3 );
115- scan (wQ, !stm_white, 4 );
116- scan (wK, !stm_white, 5 );
117-
118- scan (bP, stm_white, 0 );
119- scan (bN, stm_white, 1 );
120- scan (bB, stm_white, 2 );
121- scan (bR, stm_white, 3 );
122- scan (bQ, stm_white, 4 );
123- scan (bK, stm_white, 5 );
124+ for (int p_index = 0 ; p_index < 6 ; ++p_index) {
125+ PieceType pt = (PieceType::underlying) p_index;
126+ scan (pos.pieces (pt, WHITE), true , p_index);
127+ scan (pos.pieces (pt, BLACK), false , p_index);
128+ }
124129
125- return vec;
130+ if (stm_white) {
131+ memcpy (v1, white, 768 );
132+ memcpy (v2, black, 768 );
133+ } else {
134+ memcpy (v2, white, 768 );
135+ memcpy (v1, black, 768 );
136+ }
126137}
127138
128139} // namespace
@@ -131,6 +142,9 @@ InputVector getInputVectorFor(const Board& pos) {
131142 * Main evaluation function.
132143 */
133144Value evaluate (Position& pos) {
134- static EvalNet net (nnue::weight);
135- return net (getInputVectorFor (pos));
145+ static EvaluatorNet net;
146+ int8_t vec1[768 ];
147+ int8_t vec2[768 ];
148+ getInputRepresentationFor (pos, vec1, vec2);
149+ return net (vec1, vec2);
136150}
0 commit comments