diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp
index b0f77e89..6cd54b13 100644
--- a/src/learn/learn.cpp
+++ b/src/learn/learn.cpp
@@ -56,8 +56,6 @@
 #include <omp.h>
 #endif
 
-extern double global_learning_rate;
-
 using namespace std;
 
 template <typename T>
@@ -399,6 +397,8 @@ namespace Learner
         bool use_draw_games_in_validation = true;
         bool skip_duplicated_positions_in_training = true;
 
+        double learning_rate = 1.0;
+
         string validation_set_file_name;
         string seed;
 
@@ -697,7 +697,7 @@
             // should be no real issues happening since
             // the read/write phases are isolated.
             atomic_thread_fence(memory_order_seq_cst);
-            Eval::NNUE::update_parameters(epoch, params.verbose, calc_grad);
+            Eval::NNUE::update_parameters(epoch, params.verbose, params.learning_rate, calc_grad);
             atomic_thread_fence(memory_order_seq_cst);
 
             if (++save_count * params.mini_batch_size >= params.eval_save_interval)
@@ -737,7 +737,7 @@
                 << ", epoch " << epoch
                 << endl;
 
-            out << " - learning rate = " << global_learning_rate << endl;
+            out << " - learning rate = " << params.learning_rate << endl;
 
             // For calculation of verification data loss
             AtomicLoss test_loss_sum{};
@@ -913,7 +913,7 @@
            if (tot >= last_lr_drop + params.auto_lr_drop)
            {
                last_lr_drop = tot;
-               global_learning_rate *= params.newbob_decay;
+               params.learning_rate *= params.newbob_decay;
            }
        }
        else if (latest_loss < best_loss)
@@ -929,11 +929,11 @@
            if (--trials > 0 && !is_final)
            {
                cout
-                   << " - reducing learning rate from " << global_learning_rate
-                   << " to " << (global_learning_rate * params.newbob_decay)
+                   << " - reducing learning rate from " << params.learning_rate
+                   << " to " << (params.learning_rate * params.newbob_decay)
                    << " (" << trials << " more trials)" << endl;
 
-               global_learning_rate *= params.newbob_decay;
+               params.learning_rate *= params.newbob_decay;
            }
        }
 
@@ -961,8 +961,6 @@
        string base_dir;
        string target_dir;
 
-       global_learning_rate = 1.0;
-
        uint64_t nn_batch_size = 1000;
        string nn_options;
 
@@ -1003,7 +1001,7 @@
            else if (option == "batchsize") is >> params.mini_batch_size;
 
            // learning rate
-           else if (option == "lr") is >> global_learning_rate;
+           else if (option == "lr") is >> params.learning_rate;
 
            // Accept also the old option name.
            else if (option == "use_draw_in_training"
@@ -1115,7 +1113,7 @@
        out << " - nn_batch_size : " << nn_batch_size << endl;
        out << " - nn_options : " << nn_options << endl;
-       out << " - learning rate : " << global_learning_rate << endl;
+       out << " - learning rate : " << params.learning_rate << endl;
        out << " - use draws in training : " << params.use_draw_games_in_training << endl;
        out << " - use draws in validation : " << params.use_draw_games_in_validation << endl;
        out << " - skip repeated positions : " << params.skip_duplicated_positions_in_training << endl;
 
diff --git a/src/nnue/evaluate_nnue_learner.cpp b/src/nnue/evaluate_nnue_learner.cpp
index 3e91a7de..2a1fd6cb 100644
--- a/src/nnue/evaluate_nnue_learner.cpp
+++ b/src/nnue/evaluate_nnue_learner.cpp
@@ -18,9 +18,6 @@
 #include "misc.h"
 #include "thread_win32_osx.h"
 
-// Learning rate scale
-double global_learning_rate;
-
 // Code for learning NNUE evaluation function
 namespace Eval::NNUE {
 
@@ -181,11 +178,15 @@
    }
 
    // update the evaluation function parameters
-   void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad) {
+   void update_parameters(
+       uint64_t epoch,
+       bool verbose,
+       double learning_rate,
+       Learner::CalcGradFunc calc_grad)
+   {
        assert(batch_size > 0);
 
-       const auto learning_rate = static_cast<double>(
-           global_learning_rate / batch_size);
+       learning_rate /= batch_size;
 
        std::lock_guard lock(examples_mutex);
        std::shuffle(examples.begin(), examples.end(), rng);
diff --git a/src/nnue/evaluate_nnue_learner.h b/src/nnue/evaluate_nnue_learner.h
index 8a9786e5..d350691b 100644
--- a/src/nnue/evaluate_nnue_learner.h
+++ b/src/nnue/evaluate_nnue_learner.h
@@ -31,7 +31,11 @@ namespace Eval::NNUE {
        double weight);
 
    // update the evaluation function parameters
-   void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad);
+   void update_parameters(
+       uint64_t epoch,
+       bool verbose,
+       double learning_rate,
+       Learner::CalcGradFunc calc_grad);
 
    // Check if there are any problems with learning
    void check_health();
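
Note (not part of the patch): below is a minimal, self-contained sketch of the call flow this change produces once the global is gone. `Params`, the fixed `batch_size`, and the epoch loop are simplified stand-ins for the real Learner / Eval::NNUE code, not the actual implementation.

// Sketch only: simplified stand-ins for the real Learner / Eval::NNUE code.
#include <cstdint>
#include <iostream>

struct Params {
    double learning_rate = 1.0;  // replaces the old `global_learning_rate`
    double newbob_decay  = 0.5;
};

// Mirrors the new update_parameters() contract: the rate arrives as an
// explicit by-value argument and is scaled per example inside the function,
// so callers never see the division by batch_size.
void update_parameters(uint64_t epoch, bool verbose, double learning_rate) {
    const uint64_t batch_size = 1000;  // stand-in for the real batch size
    learning_rate /= batch_size;
    if (verbose)
        std::cout << "epoch " << epoch
                  << ", per-example lr = " << learning_rate << '\n';
    // ... apply gradients scaled by learning_rate ...
}

int main() {
    Params params;
    for (uint64_t epoch = 1; epoch <= 3; ++epoch) {
        update_parameters(epoch, /* verbose */ true, params.learning_rate);
        // Newbob-style decay now mutates params, not a global.
        params.learning_rate *= params.newbob_decay;
    }
}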