Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 104 additions & 100 deletions src/eval_constants.hpp

Large diffs are not rendered by default.

26 changes: 25 additions & 1 deletion src/eval_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}

[[nodiscard]] inline auto mg() const {
const auto mg = static_cast<u16>(m_score);

Check warning on line 43 in src/eval_types.hpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/eval_types.hpp:43:20 [readability-identifier-naming]

invalid case style for constant 'mg'

i16 v{};
std::memcpy(&v, &mg, sizeof(mg));
Expand Down Expand Up @@ -101,18 +101,33 @@
return static_cast<Value>((mg() * alpha + eg() * (max - alpha)) / max);
}

// complexity_add
PScore complexity_add(Score val) {
const Score e = eg();
if (e == 0) {
return *this;
}

const Score sum = e + val;
return PScore{
mg(), static_cast<Score>((e > 0) ? std::max(sum, Score{0}) : std::min(sum, Score{0}))};
}


friend std::ostream& operator<<(std::ostream& stream, const PScore& score) {
stream << "(" << score.mg() << "\t" << score.eg() << ")";
return stream;
}
};

using PParam = PScore;
using VParam = Score;
#else

using Score = Autograd::ValueHandle;
using PScore = Autograd::PairHandle;
using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter
using PParam = Autograd::PairPlaceholder; // Handle for the TUNABLE parameter
using VParam = Autograd::ValuePlaceholder; // Handle for the TUNABLE parameter

#endif

Expand All @@ -124,6 +139,12 @@
// Constant scalar pair (mg, eg)
#define CS(a, b) Autograd::PairPlaceholder::create((a), (b))

// Tunable scalar
#define V(a) Autograd::ValuePlaceholder::create_tunable((a))

// Constant scalar
#define CV(a) Autograd::ValuePlaceholder::create((a))

// Zero pair FOR PARAMETERS (e.g., in an array)
#define PPARAM_ZERO Autograd::PairPlaceholder::create(0, 0)

Expand All @@ -134,6 +155,9 @@
// ... (non-tuning definitions) ...
#define S(a, b) PScore((a), (b))
#define CS(a, b) PScore((a), (b))

#define V(a) Value((a))
#define CV(a) Value((a))
#define PPARAM_ZERO PScore(0, 0)
#define PSCORE_ZERO PScore(0, 0)
#endif
Expand Down
6 changes: 6 additions & 0 deletions src/evaltune_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ int main() {
};
print_sigmoid("KING_SAFETY_ACTIVATION", KING_SAFETY_ACTIVATION, 32);

std::cout << std::endl;

std::cout << "inline VParam WINNABLE_PAWNS = " << WINNABLE_PAWNS << ";\n";
std::cout << "inline VParam WINNABLE_BIAS = " << WINNABLE_BIAS << ";\n";
std::cout << std::endl;

#endif
const auto end = time::Clock::now();
std::cout << "// Epoch duration: " << time::cast<time::FloatSeconds>(end - start).count()
Expand Down
18 changes: 18 additions & 0 deletions src/evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,20 @@ PScore king_safety_activation(const Position& pos, PScore& king_safety_score) {
return activated;
}

PScore apply_winnable(const Position& pos, PScore& score) {

i32 pawn_count = (pos.bitboard_for(Color::White, PieceType::Pawn)
| pos.bitboard_for(Color::Black, PieceType::Pawn))
.ipopcount();
Score winnable = WINNABLE_PAWNS * pawn_count + WINNABLE_BIAS;

if (score.eg() < 0) {
winnable = -winnable;
}

return score.complexity_add(winnable);
}

Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
const Color us = pos.active_color();
usize phase = pos.piece_count(Color::White, PieceType::Knight)
Expand Down Expand Up @@ -450,6 +464,10 @@ Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
- king_safety_activation<Color::Black>(pos, black_king_attack_total);

eval += (us == Color::White) ? TEMPO_VAL : -TEMPO_VAL;

// Winnable
eval = apply_winnable(pos, eval);

return static_cast<Score>(eval.phase<24>(static_cast<i32>(phase)));
};

Expand Down
48 changes: 48 additions & 0 deletions src/tuning/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs)
case OpType::ValueDivPair:
res = f64x2::scalar_div(v, pair_val);
break;
case OpType::PairAddClampedSecond: {
f64 add_res = pair_val.second() + v;
if (pair_val.second() > 0) {
res = f64x2::make(pair_val.first(), std::max(0.0, add_res));
} else if (pair_val.second() < 0) {
res = f64x2::make(pair_val.first(), std::min(0.0, add_res));
} else {
res = pair_val;
}
break;
}
default:
break;
}
Expand Down Expand Up @@ -488,6 +499,43 @@ void Graph::backward() {
pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_rhs);
break;
}
case OpType::PairAddClampedSecond: {
const f64x2 grad_out = pair_grads[out_idx];
f64x2 val_lhs = pair_vals[node.lhs()];
f64 val_rhs = vals[node.rhs()];

// First component always passes through
f64x2 grad_pair = f64x2::make(grad_out.first(), 0.0);
f64 grad_rhs = 0.0;

// For the second component
if (val_lhs.second() > 0) {
// Forward: res.second = max(0.0, val_lhs.second() + val_rhs)
f64 add_res = val_lhs.second() + val_rhs;
if (add_res > 0.0) {
// Gradient flows through both inputs
grad_pair = f64x2::make(grad_out.first(), grad_out.second());
grad_rhs = grad_out.second();
}
} else if (val_lhs.second() < 0) {
// Forward: res.second = min(0.0, val_lhs.second() + val_rhs)
f64 add_res = val_lhs.second() + val_rhs;
if (add_res < 0.0) {
// Gradient flows through both inputs
grad_pair = f64x2::make(grad_out.first(), grad_out.second());
grad_rhs = grad_out.second();
}
} else {
// val_lhs.second() == 0: output = input (no addition)
grad_pair = grad_out;
grad_rhs = 0.0;
}

pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_pair);
grads[node.rhs()] += grad_rhs;
break;
}


default:
unreachable();
Expand Down
1 change: 1 addition & 0 deletions src/tuning/operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ enum class OpType : u32 {
ValueMulPair,
PairDivValue,
ValueDivPair,
PairAddClampedSecond, // For complexity

// Pair-Pair Ops
PairMulPair,
Expand Down
16 changes: 16 additions & 0 deletions src/tuning/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ f64 PairHandle::first() const {
f64 PairHandle::second() const {
return get_values().second();
}
f64 PairHandle::mg() const {
return get_values().first();
}
f64 PairHandle::eg() const {
return get_values().second();
}

void PairHandle::set_values(const f64x2& v) const {
Graph::get().set_pair_values(index, v);
Expand Down Expand Up @@ -204,6 +210,11 @@ std::ostream& operator<<(std::ostream& os, const PairHandle& p) {
return os;
}

std::ostream& operator<<(std::ostream& os, const ValueHandle& v) {
os << "V(" << std::round(v.get_value()) << ")";
return os;
}

// Value Inplaces
ValueHandle& operator+=(ValueHandle& a, ValueHandle b) {
a = a + b;
Expand Down Expand Up @@ -265,4 +276,9 @@ PairHandle& operator/=(PairHandle& a, ValueHandle v) {
return a;
}


PairHandle PairHandle::complexity_add(ValueHandle value) const {
return Graph::get().record_pair_value(OpType::PairAddClampedSecond, *this, value);
}

} // namespace Clockwork::Autograd
5 changes: 5 additions & 0 deletions src/tuning/value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ struct PairHandle {
f64x2 get_values() const;
f64x2 get_gradients() const;
f64 first() const;
f64 mg() const;
f64 second() const;
f64 eg() const;
void set_values(const f64x2& v) const;
void set_values(f64 f, f64 s) const;
void zero_grad() const;
Expand All @@ -73,6 +75,8 @@ struct PairHandle {
}

PairHandle sigmoid() const;

PairHandle complexity_add(ValueHandle value) const;
};

// Operation decls
Expand Down Expand Up @@ -105,6 +109,7 @@ PairHandle operator*(PairHandle a, PairHandle b);
PairHandle operator/(PairHandle a, ValueHandle v);
PairHandle operator/(ValueHandle v, PairHandle a);
std::ostream& operator<<(std::ostream& os, const PairHandle& p);
std::ostream& operator<<(std::ostream& os, const ValueHandle& v);

// Value Inplaces
ValueHandle& operator+=(ValueHandle& a, ValueHandle b);
Expand Down
Loading