Make all grad related functions in learn static. Pass calc_grad as a parameter.

This commit is contained in:
Tomasz Sobczyk
2020-10-24 23:29:32 +02:00
committed by nodchip
parent e4868cb59e
commit cde6ec2bf2
4 changed files with 23 additions and 30 deletions

View File

@@ -185,7 +185,7 @@ namespace Learner
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage(double value)
static double winning_percentage(double value)
{
// 1/(1+10^(-Eval/4))
// = 1/(1+e^(-Eval/4*ln(10))
@@ -194,7 +194,7 @@ namespace Learner
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage_wdl(double value, int ply)
static double winning_percentage_wdl(double value, int ply)
{
constexpr double wdl_total = 1000.0;
constexpr double draw_score = 0.5;
@@ -207,7 +207,7 @@ namespace Learner
}
// A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage(double value, int ply)
static double winning_percentage(double value, int ply)
{
if (use_wdl)
{
@@ -219,7 +219,7 @@ namespace Learner
}
}
double calc_cross_entropy_of_winning_percentage(
static double calc_cross_entropy_of_winning_percentage(
double deep_win_rate,
double shallow_eval,
int ply)
@@ -229,7 +229,7 @@ namespace Learner
return -p * std::log(q) - (1.0 - p) * std::log(1.0 - q);
}
double calc_d_cross_entropy_of_winning_percentage(
static double calc_d_cross_entropy_of_winning_percentage(
double deep_win_rate,
double shallow_eval,
int ply)
@@ -248,7 +248,7 @@ namespace Learner
}
// Training Formula · Issue #71 · nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
double get_scaled_signal(double signal)
static double get_scaled_signal(double signal)
{
double scaled_signal = signal;
@@ -266,13 +266,13 @@ namespace Learner
}
// Teacher winning probability.
double calculate_p(double teacher_signal, int ply)
static double calculate_p(double teacher_signal, int ply)
{
const double scaled_teacher_signal = get_scaled_signal(teacher_signal);
return winning_percentage(scaled_teacher_signal, ply);
}
double calculate_lambda(double teacher_signal)
static double calculate_lambda(double teacher_signal)
{
// If the evaluation value in deep search exceeds elmo_lambda_limit
// then apply elmo_lambda_high instead of elmo_lambda_low.
@@ -284,7 +284,7 @@ namespace Learner
return lambda;
}
double calculate_t(int game_result)
static double calculate_t(int game_result)
{
// Use 1 as the correction term if the expected win rate is 1,
// 0 if you lose, and 0.5 if you draw.
@@ -294,20 +294,20 @@ namespace Learner
return t;
}
double calc_grad(Value teacher_signal, Value shallow, const PackedSfenValue& psv)
static double calc_grad(Value shallow, Value teacher_signal, int result, int ply)
{
// elmo (WCSC27) method
// Correct with the actual game wins and losses.
const double q = winning_percentage(shallow, psv.gamePly);
const double p = calculate_p(teacher_signal, psv.gamePly);
const double t = calculate_t(psv.game_result);
const double q = winning_percentage(shallow, ply);
const double p = calculate_p(teacher_signal, ply);
const double t = calculate_t(result);
const double lambda = calculate_lambda(teacher_signal);
double grad;
if (use_wdl)
{
const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly);
const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly);
const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, ply);
const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, ply);
grad = lambda * dce_p + (1.0 - lambda) * dce_t;
}
else
@@ -324,7 +324,7 @@ namespace Learner
// The individual cross entropy of the win/loss term and win
// rate term of the elmo expression is returned
// to the arguments cross_entropy_eval and cross_entropy_win.
Loss calc_cross_entropy(
static Loss calc_cross_entropy(
Value teacher_signal,
Value shallow,
const PackedSfenValue& psv)
@@ -360,12 +360,6 @@ namespace Learner
return loss;
}
// Other objective functions may be considered in the future...
double calc_grad(Value shallow, const PackedSfenValue& psv)
{
return calc_grad((Value)psv.score, shallow, psv);
}
// Class to generate sfen with multiple threads
struct LearnerThink
{
@@ -703,7 +697,7 @@ namespace Learner
// should be no real issues happening since
// the read/write phases are isolated.
atomic_thread_fence(memory_order_seq_cst);
Eval::NNUE::update_parameters(epoch, params.verbose);
Eval::NNUE::update_parameters(epoch, params.verbose, calc_grad);
atomic_thread_fence(memory_order_seq_cst);
if (++save_count * params.mini_batch_size >= params.eval_save_interval)