Loss func with autograd

This commit is contained in:
Tomasz Sobczyk
2020-11-29 11:55:00 +01:00
committed by nodchip
parent 541fb8177a
commit 5a58eb803a
2 changed files with 19 additions and 2 deletions

View File

@@ -19,6 +19,7 @@
#include "learn.h"
#include "autograd.h"
#include "sfen_reader.h"
#include "misc.h"
@@ -320,6 +321,20 @@ namespace Learner
return std::clamp(grad, -max_grad, max_grad);
}
// Evaluates the training loss for a single position through the autograd
// expression machinery and returns the result of evaluating that expression.
//
// With c = winning_probability_coefficient the expression built below is
//   q      = sigmoid(arg0 * c)
//   p      = sigmoid(arg1 * c)
//   t      = (arg2 + 1) / 2
//   lambda = arg3
//   loss   = (lambda * (q - p) + (1 - lambda) * (q - t))^2
// and it is evaluated with
//   arg0 = shallow, arg1 = teacher_signal, arg2 = result,
//   arg3 = calculate_lambda(teacher_signal).
//
// NOTE(review): presumably the integer template argument of
// VariableParameter / ConstantParameter selects the matching element of the
// argument tuple, and the gradient is taken with respect to the
// VariableParameter only - confirm against autograd.h.
//
// `ply` is accepted so the signature matches the other loss/gradient
// callbacks; it is not read by this function.
static ValueWithGrad<double> get_loss(Value shallow, Value teacher_signal, int result, int ply)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // Runtime inputs, converted to double up front.
    const double shallow_score = static_cast<double>(shallow);
    const double teacher_score = static_cast<double>(teacher_signal);
    const double game_result = static_cast<double>(result);
    auto inputs = std::tuple(shallow_score, teacher_score, game_result, calculate_lambda(teacher_signal));

    // Symbolic pieces of the loss. Each one is kept as a named local that
    // stays alive until the evaluation at the end of the function.
    auto eval_prob = sigmoid(VariableParameter<double, 0>{} * winning_probability_coefficient);
    auto teacher_prob = sigmoid(ConstantParameter<double, 1>{} * winning_probability_coefficient);
    auto outcome_prob = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
    auto mix = ConstantParameter<double, 3>{};

    // Squared blend of the two differences, weighted by mix and (1 - mix).
    auto loss_expr = pow(
        mix * (eval_prob - teacher_prob) + (1.0 - mix) * (eval_prob - outcome_prob),
        2.0);

    return loss_expr.eval(inputs);
}
// Calculate cross entropy during learning
// The individual cross entropy of the win/loss term and win
// rate term of the elmo expression is returned
@@ -702,7 +717,7 @@ namespace Learner
{
goto RETRY_READ;
}
// We want the position being trained on not to be terminal
if (MoveList<LEGAL>(pos).size() == 0)
goto RETRY_READ;
@@ -720,7 +735,7 @@ namespace Learner
// should be no real issues happening since
// the read/write phases are isolated.
atomic_thread_fence(memory_order_seq_cst);
Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad);
Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad, get_loss);
atomic_thread_fence(memory_order_seq_cst);
if (++save_count * params.mini_batch_size >= params.eval_save_interval)

View File

@@ -33,6 +33,7 @@ using LearnFloatType = float;
// Definition of struct used in Learner
// ----------------------
#include "autograd.h"
#include "packed_sfen.h"
#include "position.h"
@@ -68,6 +69,7 @@ namespace Learner
// Declared here, defined elsewhere. Takes an input string stream.
// NOTE(review): presumably the stream carries the remaining arguments of
// the learn command - confirm against the caller.
void learn(std::istringstream& is);

// Function type of a gradient callback: takes (Value, Value, int, int)
// and returns a double.
// NOTE(review): the parameters presumably mean (shallow, teacher_signal,
// result, ply) - confirm against the functions passed as this type.
using CalcGradFunc = double(Value, Value, int, int);

// Function type of a loss callback: same parameter list as CalcGradFunc,
// but returns a ValueWithGrad<double> instead of a plain double.
using CalcLossFunc = ValueWithGrad<double>(Value, Value, int, int);
}
#endif // ifndef _LEARN_H_