diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 5a540d31..3648a40f 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -744,6 +744,8 @@ namespace Learner newbob_decay = 1.0; newbob_num_trials = 2; + auto_lr_drop = 0; + last_lr_drop = 0; best_loss = std::numeric_limits<double>::infinity(); latest_loss_sum = 0.0; latest_loss_count = 0; @@ -797,6 +799,8 @@ namespace Learner shared_timed_mutex nn_mutex; double newbob_decay; int newbob_num_trials; + uint64_t auto_lr_drop; + uint64_t last_lr_drop; double best_loss; double latest_loss_sum; uint64_t latest_loss_count; @@ -1295,7 +1299,21 @@ namespace Learner latest_loss_sum = 0.0; latest_loss_count = 0; cout << "loss: " << latest_loss; - if (latest_loss < best_loss) + auto tot = sr.total_done.load(); + if (auto_lr_drop) + { + cout << " < best (" << best_loss << "), accepted" << endl; + best_loss = latest_loss; + best_nn_directory = Path::Combine((std::string)Options["EvalSaveDir"], dir_name); + trials = newbob_num_trials; + + if (tot >= last_lr_drop + auto_lr_drop) + { + last_lr_drop = tot; + global_learning_rate *= newbob_decay; + } + } + else if (latest_loss < best_loss) { cout << " < best (" << best_loss << "), accepted" << endl; best_loss = latest_loss; @@ -1647,6 +1665,7 @@ namespace Learner uint64_t nn_batch_size = 1000; double newbob_decay = 0.5; int newbob_num_trials = 4; + uint64_t auto_lr_drop = 0; string nn_options; uint64_t eval_save_interval = LEARN_EVAL_SAVE_INTERVAL; @@ -1729,6 +1748,7 @@ namespace Learner else if (option == "newbob_decay") is >> newbob_decay; else if (option == "newbob_num_trials") is >> newbob_num_trials; else if (option == "nn_options") is >> nn_options; + else if (option == "auto_lr_drop") is >> auto_lr_drop; else if (option == "eval_save_interval") is >> eval_save_interval; else if (option == "loss_output_interval") is >> loss_output_interval; @@ -1972,6 +1992,7 @@ namespace Learner learn_think.newbob_decay = newbob_decay; learn_think.newbob_num_trials =
newbob_num_trials; + learn_think.auto_lr_drop = auto_lr_drop; learn_think.eval_save_interval = eval_save_interval; learn_think.loss_output_interval = loss_output_interval;