Replace the global variable global_learning_rate with a learning_rate field local to the learner's params, which is passed to update_parameters as an argument.

This commit is contained in:
Tomasz Sobczyk
2020-10-24 23:35:34 +02:00
committed by nodchip
parent cde6ec2bf2
commit f81fa3d712
3 changed files with 22 additions and 19 deletions

View File

@@ -56,8 +56,6 @@
#include <omp.h>
#endif
extern double global_learning_rate;
using namespace std;
template <typename T>
@@ -399,6 +397,8 @@ namespace Learner
bool use_draw_games_in_validation = true;
bool skip_duplicated_positions_in_training = true;
double learning_rate = 1.0;
string validation_set_file_name;
string seed;
@@ -697,7 +697,7 @@ namespace Learner
// should be no real issues happening since
// the read/write phases are isolated.
atomic_thread_fence(memory_order_seq_cst);
Eval::NNUE::update_parameters(epoch, params.verbose, calc_grad);
Eval::NNUE::update_parameters(epoch, params.verbose, params.learning_rate, calc_grad);
atomic_thread_fence(memory_order_seq_cst);
if (++save_count * params.mini_batch_size >= params.eval_save_interval)
@@ -737,7 +737,7 @@ namespace Learner
<< ", epoch " << epoch
<< endl;
out << " - learning rate = " << global_learning_rate << endl;
out << " - learning rate = " << params.learning_rate << endl;
// For calculation of verification data loss
AtomicLoss test_loss_sum{};
@@ -913,7 +913,7 @@ namespace Learner
if (tot >= last_lr_drop + params.auto_lr_drop)
{
last_lr_drop = tot;
global_learning_rate *= params.newbob_decay;
params.learning_rate *= params.newbob_decay;
}
}
else if (latest_loss < best_loss)
@@ -929,11 +929,11 @@ namespace Learner
if (--trials > 0 && !is_final)
{
cout
<< " - reducing learning rate from " << global_learning_rate
<< " to " << (global_learning_rate * params.newbob_decay)
<< " - reducing learning rate from " << params.learning_rate
<< " to " << (params.learning_rate * params.newbob_decay)
<< " (" << trials << " more trials)" << endl;
global_learning_rate *= params.newbob_decay;
params.learning_rate *= params.newbob_decay;
}
}
@@ -961,8 +961,6 @@ namespace Learner
string base_dir;
string target_dir;
global_learning_rate = 1.0;
uint64_t nn_batch_size = 1000;
string nn_options;
@@ -1003,7 +1001,7 @@ namespace Learner
else if (option == "batchsize") is >> params.mini_batch_size;
// learning rate
else if (option == "lr") is >> global_learning_rate;
else if (option == "lr") is >> params.learning_rate;
// Accept also the old option name.
else if (option == "use_draw_in_training"
@@ -1115,7 +1113,7 @@ namespace Learner
out << " - nn_batch_size : " << nn_batch_size << endl;
out << " - nn_options : " << nn_options << endl;
out << " - learning rate : " << global_learning_rate << endl;
out << " - learning rate : " << params.learning_rate << endl;
out << " - use draws in training : " << params.use_draw_games_in_training << endl;
out << " - use draws in validation : " << params.use_draw_games_in_validation << endl;
out << " - skip repeated positions : " << params.skip_duplicated_positions_in_training << endl;

View File

@@ -18,9 +18,6 @@
#include "misc.h"
#include "thread_win32_osx.h"
// Learning rate scale
double global_learning_rate;
// Code for learning NNUE evaluation function
namespace Eval::NNUE {
@@ -181,11 +178,15 @@ namespace Eval::NNUE {
}
// update the evaluation function parameters
void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad) {
void update_parameters(
uint64_t epoch,
bool verbose,
double learning_rate,
Learner::CalcGradFunc calc_grad)
{
assert(batch_size > 0);
const auto learning_rate = static_cast<LearnFloatType>(
global_learning_rate / batch_size);
learning_rate /= batch_size;
std::lock_guard<std::mutex> lock(examples_mutex);
std::shuffle(examples.begin(), examples.end(), rng);

View File

@@ -31,7 +31,11 @@ namespace Eval::NNUE {
double weight);
// update the evaluation function parameters
void update_parameters(uint64_t epoch, bool verbose, Learner::CalcGradFunc calc_grad);
void update_parameters(
uint64_t epoch,
bool verbose,
double learning_rate,
Learner::CalcGradFunc calc_grad);
// Check if there are any problems with learning
void check_health();