mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
Add option "auto_lr_drop" that specifies the amount of positions from previous lr drop after which to reduce lr by newbob_decay.
This commit is contained in:
@@ -744,6 +744,8 @@ namespace Learner
|
||||
|
||||
newbob_decay = 1.0;
|
||||
newbob_num_trials = 2;
|
||||
auto_lr_drop = 0;
|
||||
last_lr_drop = 0;
|
||||
best_loss = std::numeric_limits<double>::infinity();
|
||||
latest_loss_sum = 0.0;
|
||||
latest_loss_count = 0;
|
||||
@@ -797,6 +799,8 @@ namespace Learner
|
||||
shared_timed_mutex nn_mutex;
|
||||
double newbob_decay;
|
||||
int newbob_num_trials;
|
||||
uint64_t auto_lr_drop;
|
||||
uint64_t last_lr_drop;
|
||||
double best_loss;
|
||||
double latest_loss_sum;
|
||||
uint64_t latest_loss_count;
|
||||
@@ -1295,7 +1299,21 @@ namespace Learner
|
||||
latest_loss_sum = 0.0;
|
||||
latest_loss_count = 0;
|
||||
cout << "loss: " << latest_loss;
|
||||
if (latest_loss < best_loss)
|
||||
auto tot = sr.total_done.load();
|
||||
if (auto_lr_drop)
|
||||
{
|
||||
cout << " < best (" << best_loss << "), accepted" << endl;
|
||||
best_loss = latest_loss;
|
||||
best_nn_directory = Path::Combine((std::string)Options["EvalSaveDir"], dir_name);
|
||||
trials = newbob_num_trials;
|
||||
|
||||
if (tot >= last_lr_drop + auto_lr_drop)
|
||||
{
|
||||
last_lr_drop = tot;
|
||||
global_learning_rate *= newbob_decay;
|
||||
}
|
||||
}
|
||||
else if (latest_loss < best_loss)
|
||||
{
|
||||
cout << " < best (" << best_loss << "), accepted" << endl;
|
||||
best_loss = latest_loss;
|
||||
@@ -1647,6 +1665,7 @@ namespace Learner
|
||||
uint64_t nn_batch_size = 1000;
|
||||
double newbob_decay = 0.5;
|
||||
int newbob_num_trials = 4;
|
||||
uint64_t auto_lr_drop = 0;
|
||||
string nn_options;
|
||||
|
||||
uint64_t eval_save_interval = LEARN_EVAL_SAVE_INTERVAL;
|
||||
@@ -1729,6 +1748,7 @@ namespace Learner
|
||||
else if (option == "newbob_decay") is >> newbob_decay;
|
||||
else if (option == "newbob_num_trials") is >> newbob_num_trials;
|
||||
else if (option == "nn_options") is >> nn_options;
|
||||
else if (option == "auto_lr_drop") is >> auto_lr_drop;
|
||||
|
||||
else if (option == "eval_save_interval") is >> eval_save_interval;
|
||||
else if (option == "loss_output_interval") is >> loss_output_interval;
|
||||
@@ -1972,6 +1992,7 @@ namespace Learner
|
||||
|
||||
learn_think.newbob_decay = newbob_decay;
|
||||
learn_think.newbob_num_trials = newbob_num_trials;
|
||||
learn_think.auto_lr_drop = auto_lr_drop;
|
||||
|
||||
learn_think.eval_save_interval = eval_save_interval;
|
||||
learn_think.loss_output_interval = loss_output_interval;
|
||||
|
||||
Reference in New Issue
Block a user