From 5a58eb803a2c1b11808c96a1d8eb9c58a01d4791 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Sun, 29 Nov 2020 11:55:00 +0100 Subject: [PATCH] Loss func with autograd --- src/learn/learn.cpp | 19 +++++++++++++++++-- src/learn/learn.h | 2 ++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index f7358f8e..411cee08 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -19,6 +19,7 @@ #include "learn.h" +#include "autograd.h" #include "sfen_reader.h" #include "misc.h" @@ -320,6 +321,20 @@ namespace Learner return std::clamp(grad, -max_grad, max_grad); } + static ValueWithGrad get_loss(Value shallow, Value teacher_signal, int result, int ply) + { + using namespace Learner::Autograd::UnivariateStatic; + + auto q_ = sigmoid(VariableParameter{} * winning_probability_coefficient); + auto p_ = sigmoid(ConstantParameter{} * winning_probability_coefficient); + auto t_ = (ConstantParameter{} + 1.0) * 0.5; + auto lambda_ = ConstantParameter{}; + auto loss_ = pow(lambda_ * (q_ - p_) + (1.0 - lambda_) * (q_ - t_), 2.0); + + auto args = std::tuple((double)shallow, (double)teacher_signal, (double)result, calculate_lambda(teacher_signal)); + return loss_.eval(args); + } + // Calculate cross entropy during learning // The individual cross entropy of the win/loss term and win // rate term of the elmo expression is returned @@ -702,7 +717,7 @@ namespace Learner { goto RETRY_READ; } - + // We want to position being trained on not to be terminal if (MoveList(pos).size() == 0) goto RETRY_READ; @@ -720,7 +735,7 @@ namespace Learner // should be no real issues happening since // the read/write phases are isolated. atomic_thread_fence(memory_order_seq_cst); - Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad); + Eval::NNUE::update_parameters(Threads, epoch, params.verbose, params.learning_rate, calc_grad, get_loss); atomic_thread_fence(memory_order_seq_cst); if (++save_count * params.mini_batch_size >= params.eval_save_interval) diff --git a/src/learn/learn.h b/src/learn/learn.h index 6ce476e5..f74fd4e3 100644 --- a/src/learn/learn.h +++ b/src/learn/learn.h @@ -33,6 +33,7 @@ using LearnFloatType = float; // Definition of struct used in Learner // ---------------------- +#include "autograd.h" #include "packed_sfen.h" #include "position.h" @@ -68,6 +69,7 @@ namespace Learner void learn(std::istringstream& is); using CalcGradFunc = double(Value, Value, int, int); + using CalcLossFunc = ValueWithGrad(Value, Value, int, int); } #endif // ifndef _LEARN_H_