diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 7f18ff28..3942b606 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -93,6 +93,8 @@ namespace Learner static double elmo_lambda_high = 1.0; static double elmo_lambda_limit = 32000; + static double max_grad = 1.0; + // Using stockfish's WDL with win rate model instead of sigmoid static bool use_wdl = false; @@ -315,7 +317,7 @@ namespace Learner grad = lambda * (q - p) + (1.0 - lambda) * (q - t); } - return grad; + return std::clamp(grad, -max_grad, max_grad); } // Calculate cross entropy during learning @@ -1072,6 +1074,7 @@ namespace Learner else if (option == "lambda") is >> elmo_lambda_low; else if (option == "lambda2") is >> elmo_lambda_high; else if (option == "lambda_limit") is >> elmo_lambda_limit; + else if (option == "max_grad") is >> max_grad; else if (option == "reduction_gameply") is >> params.reduction_gameply; @@ -1175,6 +1178,7 @@ namespace Learner out << " - elmo_lambda_low : " << elmo_lambda_low << endl; out << " - elmo_lambda_high : " << elmo_lambda_high << endl; out << " - elmo_lambda_limit : " << elmo_lambda_limit << endl; + out << " - max_grad : " << max_grad << endl; out << " - eval_save_interval : " << params.eval_save_interval << " sfens" << endl; out << " - loss_output_interval : " << params.loss_output_interval << " sfens" << endl;