mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
Loss func with autograd
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
|
||||
#include "learn.h"
|
||||
|
||||
#include "autograd.h"
|
||||
#include "sfen_reader.h"
|
||||
|
||||
#include "misc.h"
|
||||
@@ -320,6 +321,20 @@ namespace Learner
|
||||
return std::clamp(grad, -max_grad, max_grad);
|
||||
}
|
||||
|
||||
// Builds the per-position training loss as an autograd expression and
// evaluates it for the given inputs.
//
// With c = winning_probability_coefficient the expression is:
//   q      = sigmoid(shallow * c)
//   p      = sigmoid(teacher_signal * c)
//   t      = (result + 1) / 2
//   lambda = calculate_lambda(teacher_signal)
//   loss   = (lambda * (q - p) + (1 - lambda) * (q - t))^2
//
// Parameters:
//   shallow        - bound to parameter slot 0, the only VariableParameter.
//   teacher_signal - bound to slot 1 (constant); also fed to calculate_lambda.
//   result         - bound to slot 2 (constant).
//                    NOTE(review): t lands in [0, 1] only if result is in
//                    {-1, 0, 1} - confirm against the caller.
//   ply            - unused here; kept so the signature stays compatible with
//                    the callback passed to Eval::NNUE::update_parameters.
//
// Returns the result of evaluating the expression.
// NOTE(review): presumably the gradient part of ValueWithGrad is taken with
// respect to `shallow`, since it is the only VariableParameter - confirm
// against autograd.h.
static ValueWithGrad<double> get_loss(Value shallow, Value teacher_signal, int result, [[maybe_unused]] int ply)
{
    using namespace Learner::Autograd::UnivariateStatic;

    // Symbolic expression; the template index selects the slot of the
    // argument tuple that the parameter reads from.
    auto q_ = sigmoid(VariableParameter<double, 0>{} * winning_probability_coefficient);
    auto p_ = sigmoid(ConstantParameter<double, 1>{} * winning_probability_coefficient);
    auto t_ = (ConstantParameter<double, 2>{} + 1.0) * 0.5;
    auto lambda_ = ConstantParameter<double, 3>{};
    auto loss_ = pow(lambda_ * (q_ - p_) + (1.0 - lambda_) * (q_ - t_), 2.0);

    // The tuple order must match the parameter indices used above.
    auto args = std::tuple(
        static_cast<double>(shallow),
        static_cast<double>(teacher_signal),
        static_cast<double>(result),
        calculate_lambda(teacher_signal));

    return loss_.eval(args);
}
|
||||
|
||||
// Calculate cross entropy during learning
|
||||
// The individual cross entropy of the win/loss term and win
|
||||
// rate term of the elmo expression is returned
|
||||
@@ -702,7 +717,7 @@ namespace Learner
|
||||
{
|
||||
goto RETRY_READ;
|
||||
}
|
||||
|
||||
|
||||
// We want the position being trained on not to be terminal
|
||||
if (MoveList<LEGAL>(pos).size() == 0)
|
||||
goto RETRY_READ;
|
||||
@@ -720,7 +735,7 @@ namespace Learner
|
||||
// should be no real issues happening since
|
||||
// the read/write phases are isolated.
|
||||
atomic_thread_fence(memory_order_seq_cst);
|
||||
Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad);
|
||||
Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad, get_loss);
|
||||
atomic_thread_fence(memory_order_seq_cst);
|
||||
|
||||
if (++save_count * params.mini_batch_size >= params.eval_save_interval)
|
||||
|
||||
@@ -33,6 +33,7 @@ using LearnFloatType = float;
|
||||
// Definition of struct used in Learner
|
||||
// ----------------------
|
||||
|
||||
#include "autograd.h"
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "position.h"
|
||||
@@ -68,6 +69,7 @@ namespace Learner
|
||||
void learn(std::istringstream& is);
|
||||
|
||||
// Function type of a gradient callback: takes (Value, Value, int, int) and
// returns a double.
// NOTE(review): the arguments are presumably (shallow, teacher_signal,
// result, ply) - confirm against the callback implementation.
using CalcGradFunc = double(Value, Value, int, int);
|
||||
// Function type of a loss callback: takes (Value, Value, int, int) and
// returns a ValueWithGrad<double>.
// NOTE(review): the arguments are presumably (shallow, teacher_signal,
// result, ply) - confirm against the callback implementation.
using CalcLossFunc = ValueWithGrad<double>(Value, Value, int, int);
|
||||
}
|
||||
|
||||
#endif // ifndef _LEARN_H_
|
||||
|
||||
Reference in New Issue
Block a user