Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions src/selfplay/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ const OptionId kNnueBestMoveId{
const OptionId kDeleteFilesId{"delete-files", "",
"Delete the input files after processing."};

const OptionId kDrawValueId{"draw-value", "",
"The value to use as a draw. Can bring the value "
"of the draw closer to a loss or a win. "
"Loss is 0, win is 1."};

const OptionId kLogFileId{"logfile", "LogFile",
"Write log to that file. Special value <stderr> to "
"output the log to the console."};
Expand Down Expand Up @@ -453,6 +458,21 @@ std::string AsNnueString(const Position& p, Move m, float q, int result) {
return out.str();
}

float QWithDrawValue(const V6TrainingData& data, bool useBest, float drawValue) {
const float original_q = useBest ? data.best_q : data.played_q;
const float original_d = useBest ? data.best_d : data.played_d;
// original_q is in [-1, 1]
const float w = (original_q + 1.0f - original_d) / 2.0f;
const float d = original_d;
const float l = w - original_q;
const float q =
w * 1.0f
+ d * drawValue
+ l * 0.0f;
// q in [0, 1], scale to [-1, 1]
return q * 2.0f - 1.0f;
}

struct ProcessFileFlags {
bool delete_files : 1;
bool nnue_best_score : 1;
Expand All @@ -461,7 +481,7 @@ struct ProcessFileFlags {

void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
std::string outputDir, float distTemp, float distOffset,
float dtzBoost, int newInputFormat,
float dtzBoost, float drawValue, int newInputFormat,
std::string nnue_plain_file, ProcessFileFlags flags) {
// Scope to ensure reader and writer are closed before deleting source file.
{
Expand Down Expand Up @@ -1074,11 +1094,11 @@ void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
Move m = MoveFromNNIndex(
flags.nnue_best_move ? chunk.best_idx : chunk.played_idx,
TransformForPosition(format, history));
float q = flags.nnue_best_score ? chunk.best_q : chunk.played_q;
const float q = QWithDrawValue(chunk, flags.nnue_best_score, drawValue);
out << AsNnueString(p, m, q, round(chunk.result_q));
} else if (i < moves.size()) {
out << AsNnueString(p, moves[i], chunk.best_q,
round(chunk.result_q));
const float q = QWithDrawValue(chunk, true, drawValue);
out << AsNnueString(p, moves[i], q, round(chunk.result_q));
}
if (i < moves.size()) {
history.Append(moves[i]);
Expand Down Expand Up @@ -1108,7 +1128,7 @@ void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
void ProcessFiles(const std::vector<std::string>& files,
SyzygyTablebase* tablebase, std::string outputDir,
float distTemp, float distOffset, float dtzBoost,
int newInputFormat, int offset, int mod,
float drawValue, int newInputFormat, int offset, int mod,
std::string nnue_plain_file, ProcessFileFlags flags) {
std::cerr << "Thread: " << offset << " starting" << std::endl;
for (int i = offset; i < files.size(); i += mod) {
Expand All @@ -1117,7 +1137,7 @@ void ProcessFiles(const std::vector<std::string>& files,
continue;
}
ProcessFile(files[i], tablebase, outputDir, distTemp, distOffset, dtzBoost,
newInputFormat, nnue_plain_file, flags);
drawValue, newInputFormat, nnue_plain_file, flags);
}
}

Expand Down Expand Up @@ -1214,6 +1234,7 @@ void RescoreLoop::RunLoop() {
options_.Add<BoolOption>(kNnueBestScoreId) = true;
options_.Add<BoolOption>(kNnueBestMoveId) = false;
options_.Add<BoolOption>(kDeleteFilesId) = true;
options_.Add<FloatOption>(kDrawValueId, 0.0f, 1.0f) = 0.5f;

SelfPlayTournament::PopulateOptions(&options_);

Expand Down Expand Up @@ -1299,7 +1320,9 @@ void RescoreLoop::RunLoop() {
options_.GetOptionsDict().Get<std::string>(kOutputDirId),
options_.GetOptionsDict().Get<float>(kTempId),
options_.GetOptionsDict().Get<float>(kDistributionOffsetId),
dtz_boost, options_.GetOptionsDict().Get<int>(kNewInputFormatId),
dtz_boost,
options_.GetOptionsDict().Get<float>(kDrawValueId),
options_.GetOptionsDict().Get<int>(kNewInputFormatId),
offset_val, threads,
options_.GetOptionsDict().Get<std::string>(kNnuePlainFileId),
flags);
Expand All @@ -1315,6 +1338,7 @@ void RescoreLoop::RunLoop() {
options_.GetOptionsDict().Get<float>(kTempId),
options_.GetOptionsDict().Get<float>(kDistributionOffsetId),
dtz_boost,
options_.GetOptionsDict().Get<float>(kDrawValueId),
options_.GetOptionsDict().Get<int>(kNewInputFormatId), 0, 1,
options_.GetOptionsDict().Get<std::string>(kNnuePlainFileId),
flags);
Expand Down