diff --git a/src/learn/learner.cpp b/src/learn/learner.cpp index 4f8b3fee..a139bb5f 100644 --- a/src/learn/learner.cpp +++ b/src/learn/learner.cpp @@ -137,6 +137,8 @@ bool convert_teacher_signal_to_winning_probability = true; // generation and training don't work well. // https://discordapp.com/channels/435943710472011776/733545871911813221/748524079761326192 bool use_raw_nnue_eval = true; +// When true, evals are converted to win rates with the WDL win rate model instead of the sigmoid +bool use_wdl = false; // ----------------------------------- // write phase file @@ -1168,6 +1170,45 @@ double winning_percentage(double value) // = sigmoid(Eval/4*ln(10)) return sigmoid(value * winning_probability_coefficient); } + +// Converts an evaluation value to an expected score in [0,1] using the WDL win rate model: win rate plus half of the draw rate, scaled down from per mille +double winning_percentage_wdl(double value, int ply) +{ + double wdl_w = UCI::win_rate_model_double( value, ply); + double wdl_l = UCI::win_rate_model_double(-value, ply); + double wdl_d = 1000.0 - wdl_w - wdl_l; + + return (wdl_w + wdl_d / 2.0) / 1000.0; +} + +// Converts an evaluation value to the winning rate [0,1]: WDL win rate model when use_wdl is set, plain sigmoid otherwise (ply is then ignored) +double winning_percentage(double value, int ply) +{ + if (use_wdl) { + return winning_percentage_wdl(value, ply); + } + else { + return winning_percentage(value); + } +} + +double calc_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply) +{ + double p = deep_win_rate; + double q = winning_percentage(shallow_eval, ply); + return -p * std::log(q) - (1 - p) * std::log(1 - q); +} + +double calc_d_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply) +{ + constexpr double epsilon = 0.000001; + double y1 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval , ply); + double y2 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval + epsilon, ply); + + // One-sided finite difference in shallow_eval; divide by the winning_probability_coefficient to match scale with the sigmoidal win rate + return ((y2 - y1) / epsilon) / 
winning_probability_coefficient; +} + double dsigmoid(double x) { // Sigmoid function @@ -1269,11 +1310,11 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps // Scale to [dest_score_min_value, dest_score_max_value]. scaled_teacher_signal = scaled_teacher_signal * (dest_score_max_value - dest_score_min_value) + dest_score_min_value; - const double q = winning_percentage(shallow); + const double q = winning_percentage(shallow, psv.gamePly); // Teacher winning probability. double p = scaled_teacher_signal; if (convert_teacher_signal_to_winning_probability) { - p = winning_percentage(scaled_teacher_signal); + p = winning_percentage(scaled_teacher_signal, psv.gamePly); } // Use 1 as the correction term if the expected win rate is 1, 0 if you lose, and 0.5 if you draw. @@ -1283,9 +1324,17 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps // If the evaluation value in deep search exceeds ELMO_LAMBDA_LIMIT, apply ELMO_LAMBDA2 instead of ELMO_LAMBDA. const double lambda = (abs(teacher_signal) >= ELMO_LAMBDA_LIMIT) ? ELMO_LAMBDA2 : ELMO_LAMBDA; - // Use the actual win rate as a correction term. - // This is the idea of ​​elmo (WCSC27), modern O-parts. - const double grad = lambda * (q - p) + (1.0 - lambda) * (q - t); + double grad; + if (use_wdl) { + double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly); + double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly); + grad = lambda * dce_p + (1.0 - lambda) * dce_t; + } + else { + // Use the actual win rate as a correction term. + // This is the idea of elmo (WCSC27), modern O-parts. 
+ grad = lambda * (q - p) + (1.0 - lambda) * (q - t); + } return grad; } @@ -3174,6 +3223,8 @@ void learn(Position&, istringstream& is) else if (option == "winning_probability_coefficient") is >> winning_probability_coefficient; // Discount rate else if (option == "discount_rate") is >> discount_rate; + // Use the WDL win rate model instead of the sigmoid + else if (option == "use_wdl") is >> use_wdl; // No learning of KK/KKP/KPP/KPPP. else if (option == "freeze_kk") is >> freeze[0]; diff --git a/src/uci.cpp b/src/uci.cpp index 8972cec9..00941040 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -238,27 +238,34 @@ namespace { // The win rate model returns the probability (per mille) of winning given an eval // and a game-ply. The model fits rather accurately the LTC fishtest statistics. int win_rate_model(Value v, int ply) { - - // The model captures only up to 240 plies, so limit input (and rescale) - double m = std::min(240, ply) / 64.0; - - // Coefficients of a 3rd order polynomial fit based on fishtest data - // for two parameters needed to transform eval to the argument of a - // logistic function. - double as[] = {-8.24404295, 64.23892342, -95.73056462, 153.86478679}; - double bs[] = {-3.37154371, 28.44489198, -56.67657741, 72.05858751}; - double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; - double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; - - // Transform eval to centipawns with limited range - double x = Utility::clamp(double(100 * v) / PawnValueEg, -1000.0, 1000.0); - // Return win rate in per mille (rounded to nearest) - return int(0.5 + 1000 / (1 + std::exp((a - x) / b))); + return int(0.5 + UCI::win_rate_model_double(v, ply)); } } // namespace +// Unrounded win rate model: returns the probability (per mille, as a double) of winning +// given an eval and a game-ply. The model fits rather accurately the LTC fishtest statistics. 
+double UCI::win_rate_model_double(double v, int ply) { + + // The model captures only up to 240 plies, so limit input (and rescale) + double m = std::min(240, ply) / 64.0; + + // Coefficients of a 3rd order polynomial fit based on fishtest data + // for two parameters needed to transform eval to the argument of a + // logistic function. + double as[] = {-8.24404295, 64.23892342, -95.73056462, 153.86478679}; + double bs[] = {-3.37154371, 28.44489198, -56.67657741, 72.05858751}; + double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3]; + double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3]; + + // Transform eval to centipawns, clamped to [-1000, 1000] + double x = Utility::clamp(double(100 * v) / PawnValueEg, -1000.0, 1000.0); + + // Return win rate in per mille (not rounded) + return 1000.0 / (1 + std::exp((a - x) / b)); +} + // -------------------- // Call qsearch(),search() directly for testing // -------------------- diff --git a/src/uci.h b/src/uci.h index 27a50fb9..c0e8372f 100644 --- a/src/uci.h +++ b/src/uci.h @@ -72,6 +72,7 @@ std::string square(Square s); std::string move(Move m, bool chess960); std::string pv(const Position& pos, Depth depth, Value alpha, Value beta); std::string wdl(Value v, int ply); +double win_rate_model_double(double v, int ply); Move to_move(const Position& pos, std::string& str); } // namespace UCI