diff --git a/src/selfplay/loop.cc b/src/selfplay/loop.cc
index 18accd3940..706b583c6e 100644
--- a/src/selfplay/loop.cc
+++ b/src/selfplay/loop.cc
@@ -92,6 +92,11 @@ const OptionId kNnueBestMoveId{
 const OptionId kDeleteFilesId{"delete-files", "",
                               "Delete the input files after processing."};
 
+const OptionId kDrawValueId{"draw-value", "",
+                            "The value to use as a draw. Can bring the value "
+                            "of the draw closer to a loss or a win. "
+                            "Loss is 0, win is 1."};
+
 const OptionId kLogFileId{"logfile", "LogFile",
                           "Write log to that file. Special value <stderr> to "
                           "output the log to the console."};
@@ -453,6 +458,23 @@ std::string AsNnueString(const Position& p, Move m, float q, int result) {
   return out.str();
 }
 
+// Re-weights the draw probability of a training record and returns the
+// resulting expected score on the [-1, 1] scale. A win counts as 1, a loss as
+// 0 and a draw as `drawValue`, all on the [0, 1] scale.
+//
+// With w + d + l = 1 and q = w - l, the re-weighted score
+// 2 * (w + d * drawValue) - 1 reduces to q + d * (2 * drawValue - 1). The
+// closed form is used so that the default drawValue of 0.5 hands back the
+// stored q unchanged instead of a value perturbed by float rounding.
+//
+// `useBest` selects best_q/best_d; otherwise played_q/played_d are used.
+float QWithDrawValue(const V6TrainingData& data, bool useBest,
+                     float drawValue) {
+  const float original_q = useBest ? data.best_q : data.played_q;
+  const float original_d = useBest ? data.best_d : data.played_d;
+  return original_q + original_d * (2.0f * drawValue - 1.0f);
+}
+
 struct ProcessFileFlags {
   bool delete_files : 1;
   bool nnue_best_score : 1;
@@ -461,7 +483,7 @@ struct ProcessFileFlags {
 
 void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
                  std::string outputDir, float distTemp, float distOffset,
-                 float dtzBoost, int newInputFormat,
+                 float dtzBoost, float drawValue, int newInputFormat,
                  std::string nnue_plain_file, ProcessFileFlags flags) {
   // Scope to ensure reader and writer are closed before deleting source file.
   {
@@ -1074,11 +1096,11 @@ void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
           Move m = MoveFromNNIndex(
               flags.nnue_best_move ? chunk.best_idx : chunk.played_idx,
               TransformForPosition(format, history));
-          float q = flags.nnue_best_score ? chunk.best_q : chunk.played_q;
+          const float q = QWithDrawValue(chunk, flags.nnue_best_score, drawValue);
           out << AsNnueString(p, m, q, round(chunk.result_q));
         } else if (i < moves.size()) {
-          out << AsNnueString(p, moves[i], chunk.best_q,
-                              round(chunk.result_q));
+          const float q = QWithDrawValue(chunk, true, drawValue);
+          out << AsNnueString(p, moves[i], q, round(chunk.result_q));
         }
         if (i < moves.size()) {
           history.Append(moves[i]);
@@ -1108,7 +1130,7 @@ void ProcessFile(const std::string& file, SyzygyTablebase* tablebase,
 void ProcessFiles(const std::vector<std::string>& files,
                   SyzygyTablebase* tablebase, std::string outputDir,
                   float distTemp, float distOffset, float dtzBoost,
-                  int newInputFormat, int offset, int mod,
+                  float drawValue, int newInputFormat, int offset, int mod,
                   std::string nnue_plain_file, ProcessFileFlags flags) {
   std::cerr << "Thread: " << offset << " starting" << std::endl;
   for (int i = offset; i < files.size(); i += mod) {
@@ -1117,7 +1139,7 @@ void ProcessFiles(const std::vector<std::string>& files,
       continue;
     }
     ProcessFile(files[i], tablebase, outputDir, distTemp, distOffset, dtzBoost,
-                newInputFormat, nnue_plain_file, flags);
+                drawValue, newInputFormat, nnue_plain_file, flags);
   }
 }
 
@@ -1214,6 +1236,7 @@ void RescoreLoop::RunLoop() {
   options_.Add<BoolOption>(kNnueBestScoreId) = true;
   options_.Add<BoolOption>(kNnueBestMoveId) = false;
   options_.Add<BoolOption>(kDeleteFilesId) = true;
+  options_.Add<FloatOption>(kDrawValueId, 0.0f, 1.0f) = 0.5f;
 
   SelfPlayTournament::PopulateOptions(&options_);
 
@@ -1299,7 +1322,9 @@ void RescoreLoop::RunLoop() {
             options_.GetOptionsDict().Get<std::string>(kOutputDirId),
             options_.GetOptionsDict().Get<float>(kTempId),
             options_.GetOptionsDict().Get<float>(kDistributionOffsetId),
-            dtz_boost, options_.GetOptionsDict().Get<int>(kNewInputFormatId),
+            dtz_boost,
+            options_.GetOptionsDict().Get<float>(kDrawValueId),
+            options_.GetOptionsDict().Get<int>(kNewInputFormatId),
             offset_val, threads,
             options_.GetOptionsDict().Get<std::string>(kNnuePlainFileId),
             flags);
@@ -1315,6 +1340,7 @@ void RescoreLoop::RunLoop() {
                  options_.GetOptionsDict().Get<float>(kTempId),
                  options_.GetOptionsDict().Get<float>(kDistributionOffsetId),
                  dtz_boost,
+                 options_.GetOptionsDict().Get<float>(kDrawValueId),
                  options_.GetOptionsDict().Get<int>(kNewInputFormatId), 0, 1,
                  options_.GetOptionsDict().Get<std::string>(kNnuePlainFileId),
                  flags);