Use winning_percentage_wdl in learn

This commit is contained in:
tttak
2020-08-24 22:56:08 +09:00
parent 7ee8a2bbb7
commit 4ce30d9522
3 changed files with 80 additions and 21 deletions

View File

@@ -133,6 +133,8 @@ double dest_score_max_value = 1.0;
// probabilities in the trainer. Sometimes we want to use the winning probabilities in the training
// data directly. In those cases, we set false to this variable.
bool convert_teacher_signal_to_winning_probability = true;
// Using WDL with win rate model instead of sigmoid
bool use_wdl = false;
// -----------------------------------
// write phase file
@@ -1162,6 +1164,45 @@ double winning_percentage(double value)
// = sigmoid(Eval/4*ln(10))
return sigmoid(value * winning_probability_coefficient);
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage_wdl(double value, int ply)
{
double wdl_w = UCI::win_rate_model_double( value, ply);
double wdl_l = UCI::win_rate_model_double(-value, ply);
double wdl_d = 1000.0 - wdl_w - wdl_l;
return (wdl_w + wdl_d / 2.0) / 1000.0;
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage(double value, int ply)
{
if (use_wdl) {
return winning_percentage_wdl(value, ply);
}
else {
return winning_percentage(value);
}
}
double calc_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply)
{
double p = deep_win_rate;
double q = winning_percentage(shallow_eval, ply);
return -p * std::log(q) - (1 - p) * std::log(1 - q);
}
double calc_d_cross_entropy_of_winning_percentage(double deep_win_rate, double shallow_eval, int ply)
{
constexpr double epsilon = 0.000001;
double y1 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval , ply);
double y2 = calc_cross_entropy_of_winning_percentage(deep_win_rate, shallow_eval + epsilon, ply);
// Divide by the winning_probability_coefficient to match scale with the sigmoidal win rate
return ((y2 - y1) / epsilon) / winning_probability_coefficient;
}
double dsigmoid(double x)
{
// Sigmoid function
@@ -1263,11 +1304,11 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps
// Scale to [dest_score_min_value, dest_score_max_value].
scaled_teacher_signal = scaled_teacher_signal * (dest_score_max_value - dest_score_min_value) + dest_score_min_value;
const double q = winning_percentage(shallow);
const double q = winning_percentage(shallow, psv.gamePly);
// Teacher winning probability.
double p = scaled_teacher_signal;
if (convert_teacher_signal_to_winning_probability) {
p = winning_percentage(scaled_teacher_signal);
p = winning_percentage(scaled_teacher_signal, psv.gamePly);
}
// Use 1 as the correction term if the expected win rate is 1, 0 if you lose, and 0.5 if you draw.
@@ -1277,9 +1318,17 @@ double calc_grad(Value teacher_signal, Value shallow , const PackedSfenValue& ps
// If the evaluation value in deep search exceeds ELMO_LAMBDA_LIMIT, apply ELMO_LAMBDA2 instead of ELMO_LAMBDA.
const double lambda = (abs(teacher_signal) >= ELMO_LAMBDA_LIMIT) ? ELMO_LAMBDA2 : ELMO_LAMBDA;
// Use the actual win rate as a correction term.
// This is the idea of elmo (WCSC27), modern O-parts.
const double grad = lambda * (q - p) + (1.0 - lambda) * (q - t);
double grad;
if (use_wdl) {
double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly);
double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly);
grad = lambda * dce_p + (1.0 - lambda) * dce_t;
}
else {
// Use the actual win rate as a correction term.
// This is the idea of elmo (WCSC27), modern O-parts.
grad = lambda * (q - p) + (1.0 - lambda) * (q - t);
}
return grad;
}
@@ -3168,6 +3217,8 @@ void learn(Position&, istringstream& is)
else if (option == "winning_probability_coefficient") is >> winning_probability_coefficient;
// Discount rate
else if (option == "discount_rate") is >> discount_rate;
// Using WDL with win rate model instead of sigmoid
else if (option == "use_wdl") is >> use_wdl;
// No learning of KK/KKP/KPP/KPPP.
else if (option == "freeze_kk") is >> freeze[0];

View File

@@ -238,27 +238,34 @@ namespace {
// The win rate model returns the probability (per mille) of winning given an eval
// and a game-ply. The model fits rather accurately the LTC fishtest statistics.
int win_rate_model(Value v, int ply) {
// The model captures only up to 240 plies, so limit input (and rescale)
double m = std::min(240, ply) / 64.0;
// Coefficients of a 3rd order polynomial fit based on fishtest data
// for two parameters needed to transform eval to the argument of a
// logistic function.
double as[] = {-8.24404295, 64.23892342, -95.73056462, 153.86478679};
double bs[] = {-3.37154371, 28.44489198, -56.67657741, 72.05858751};
double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3];
double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3];
// Transform eval to centipawns with limited range
double x = Utility::clamp(double(100 * v) / PawnValueEg, -1000.0, 1000.0);
// Return win rate in per mille (rounded to nearest)
return int(0.5 + 1000 / (1 + std::exp((a - x) / b)));
return int(0.5 + UCI::win_rate_model_double(v, ply));
}
} // namespace
// The win rate model returns the probability (per mille) of winning given an eval
// and a game-ply. The model fits rather accurately the LTC fishtest statistics.
double UCI::win_rate_model_double(double v, int ply) {
// The model captures only up to 240 plies, so limit input (and rescale)
double m = std::min(240, ply) / 64.0;
// Coefficients of a 3rd order polynomial fit based on fishtest data
// for two parameters needed to transform eval to the argument of a
// logistic function.
double as[] = {-8.24404295, 64.23892342, -95.73056462, 153.86478679};
double bs[] = {-3.37154371, 28.44489198, -56.67657741, 72.05858751};
double a = (((as[0] * m + as[1]) * m + as[2]) * m) + as[3];
double b = (((bs[0] * m + bs[1]) * m + bs[2]) * m) + bs[3];
// Transform eval to centipawns with limited range
double x = Utility::clamp(double(100 * v) / PawnValueEg, -1000.0, 1000.0);
// Return win rate in per mille
return 1000.0 / (1 + std::exp((a - x) / b));
}
// --------------------
// Call qsearch(),search() directly for testing
// --------------------

View File

@@ -72,6 +72,7 @@ std::string square(Square s);
std::string move(Move m, bool chess960);
std::string pv(const Position& pos, Depth depth, Value alpha, Value beta);
std::string wdl(Value v, int ply);
double win_rate_model_double(double v, int ply);
Move to_move(const Position& pos, std::string& str);
} // namespace UCI