mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Make all grad related functions in learn static. Pass calc_grad as a parameter.
This commit is contained in:
@@ -18,8 +18,6 @@
|
||||
#include "misc.h"
|
||||
#include "thread_win32_osx.h"
|
||||
|
||||
#include "learn/learn.h"
|
||||
|
||||
// Learning rate scale
|
||||
double global_learning_rate;
|
||||
|
||||
@@ -183,7 +181,7 @@ namespace Eval::NNUE {
|
||||
}
|
||||
|
||||
// update the evaluation function parameters
|
||||
void update_parameters(uint64_t epoch, bool verbose) {
|
||||
void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad) {
|
||||
assert(batch_size > 0);
|
||||
|
||||
const auto learning_rate = static_cast<LearnFloatType>(
|
||||
@@ -210,7 +208,8 @@ namespace Eval::NNUE {
|
||||
batch[b].sign * network_output[b] * kPonanzaConstant));
|
||||
const auto discrete = batch[b].sign * batch[b].discrete_nn_eval;
|
||||
const auto& psv = batch[b].psv;
|
||||
const double gradient = batch[b].sign * Learner::calc_grad(shallow, psv);
|
||||
const double gradient =
|
||||
batch[b].sign * calc_grad(shallow, (Value)psv.score, psv.game_result, psv.gamePly);
|
||||
gradients[b] = static_cast<LearnFloatType>(gradient * batch[b].weight);
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user