From b71d1e86205505997106348afa7e359b9f6593c1 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Sun, 29 Nov 2020 11:55:15 +0100 Subject: [PATCH] Pass the new loss function to update_parameters --- src/nnue/evaluate_nnue_learner.cpp | 5 ++++- src/nnue/evaluate_nnue_learner.h | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/nnue/evaluate_nnue_learner.cpp b/src/nnue/evaluate_nnue_learner.cpp index 6e0572dd..822c56b4 100644 --- a/src/nnue/evaluate_nnue_learner.cpp +++ b/src/nnue/evaluate_nnue_learner.cpp @@ -195,8 +195,11 @@ namespace Eval::NNUE { uint64_t epoch, bool verbose, double learning_rate, - Learner::CalcGradFunc calc_grad) + Learner::CalcGradFunc calc_grad, + Learner::CalcLossFunc calc_loss) { + using namespace Learner::Autograd::UnivariateStatic; + assert(batch_size > 0); learning_rate /= batch_size; diff --git a/src/nnue/evaluate_nnue_learner.h b/src/nnue/evaluate_nnue_learner.h index 8633f713..0fe8afce 100644 --- a/src/nnue/evaluate_nnue_learner.h +++ b/src/nnue/evaluate_nnue_learner.h @@ -38,7 +38,8 @@ namespace Eval::NNUE { uint64_t epoch, bool verbose, double learning_rate, - Learner::CalcGradFunc calc_grad); + Learner::CalcGradFunc calc_grad, + Learner::CalcLossFunc calc_loss); // Check if there are any problems with learning void check_health();