diff --git a/src/nnue/evaluate_nnue_learner.cpp b/src/nnue/evaluate_nnue_learner.cpp index 6e0572dd..822c56b4 100644 --- a/src/nnue/evaluate_nnue_learner.cpp +++ b/src/nnue/evaluate_nnue_learner.cpp @@ -195,8 +195,11 @@ namespace Eval::NNUE { uint64_t epoch, bool verbose, double learning_rate, - Learner::CalcGradFunc calc_grad) + Learner::CalcGradFunc calc_grad, + Learner::CalcLossFunc calc_loss) { + using namespace Learner::Autograd::UnivariateStatic; + assert(batch_size > 0); learning_rate /= batch_size; diff --git a/src/nnue/evaluate_nnue_learner.h b/src/nnue/evaluate_nnue_learner.h index 8633f713..0fe8afce 100644 --- a/src/nnue/evaluate_nnue_learner.h +++ b/src/nnue/evaluate_nnue_learner.h @@ -38,7 +38,8 @@ namespace Eval::NNUE { uint64_t epoch, bool verbose, double learning_rate, - Learner::CalcGradFunc calc_grad); + Learner::CalcGradFunc calc_grad, + Learner::CalcLossFunc calc_loss); // Check if there are any problems with learning void check_health();