diff --git a/docs/learn.md b/docs/learn.md index 30a7c951..e88de089 100644 --- a/docs/learn.md +++ b/docs/learn.md @@ -80,6 +80,8 @@ Currently the following options are available: `validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy) +`validation_count` - the number of positions to use for validation. Default: 2000. + `sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M) `thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000 diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 6651e096..90f629e1 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -482,6 +482,12 @@ namespace Learner // Mini batch size size. Be sure to set it on the side that uses this class. uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE; + // Number of phases used for calculation such as mse + // mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time. + // Since search() is performed with depth = 1 in calculation of + // move match rate, simple comparison is not possible... + uint64_t validation_count = 2000; + // Option to exclude early stage from learning int reduction_gameply = 1; @@ -550,16 +556,10 @@ namespace Learner } }; - // Number of phases used for calculation such as mse - // mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time. - // Since search() is performed with depth = 1 in calculation of - // move match rate, simple comparison is not possible... - static constexpr uint64_t sfen_for_mse_size = 2000; - LearnerThink(const Params& prm) : params(prm), prng(prm.seed), - sr( + train_sr( prm.filenames, prm.shuffle, SfenReaderMode::Cyclic, @@ -567,6 +567,14 @@ namespace Learner std::to_string(prng.next_random_seed()), prm.sfen_read_size, prm.thread_buffer_size), + validation_sr( + prm.validation_set_file_name.empty() ? 
prm.filenames : std::vector{ prm.validation_set_file_name }, + prm.shuffle, + SfenReaderMode::Cyclic, + 1, + std::to_string(prng.next_random_seed()), + prm.sfen_read_size, + prm.thread_buffer_size), learn_loss_sum{} { save_count = 0; @@ -612,7 +620,8 @@ namespace Learner PRNG prng; // sfen reader - SfenReader sr; + SfenReader train_sr; + SfenReader validation_sr; uint64_t save_count; uint64_t loss_output_count; @@ -666,28 +675,26 @@ namespace Learner Eval::NNUE::verify_any_net_loaded(); - const PSVector sfen_for_mse = - params.validation_set_file_name.empty() - ? sr.read_for_mse(sfen_for_mse_size) - : sr.read_validation_set( - params.validation_set_file_name, + const PSVector validation_data = + validation_sr.read_some( + params.validation_count, params.eval_limit, - params.use_draw_games_in_validation); + params.use_draw_games_in_validation + ); - if (params.validation_set_file_name.empty() - && sfen_for_mse.size() != sfen_for_mse_size) + if (validation_data.size() != params.validation_count) { auto out = sync_region_cout.new_region(); out - << "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size() - << " out of " << sfen_for_mse_size << '\n'; + << "INFO (learn): Error reading validation data. Read " << validation_data.size() + << " out of " << params.validation_count << '\n'; return; } if (params.newbob_decay != 1.0) { - calc_loss(sfen_for_mse, 0); + calc_loss(validation_data, 0); best_loss = latest_loss_sum / latest_loss_count; latest_loss_sum = 0.0; @@ -714,7 +721,7 @@ namespace Learner if (stop_flag) break; - update_weights(sfen_for_mse, epoch); + update_weights(validation_data, epoch); if (stop_flag) break; @@ -742,7 +749,7 @@ namespace Learner RETRY_READ:; - if (!sr.read_to_thread_buffer(thread_id, ps)) + if (!train_sr.read_to_thread_buffer(thread_id, ps)) { // If we ran out of data we stop completely // because there's nothing left to do. 
@@ -1146,6 +1153,7 @@ namespace Learner is >> filename; params.filenames.push_back(filename); } + else if (option == "validation_count") is >> params.validation_count; // Specify the number of loops else if (option == "epochs") is >> epochs; @@ -1260,6 +1268,7 @@ namespace Learner out << " - validation set : " << params.validation_set_file_name << endl; } + out << " - validation count : " << params.validation_count << endl; out << " - epochs : " << epochs << endl; out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl; out << " - eval_limit : " << params.eval_limit << endl; diff --git a/src/learn/sfen_reader.h b/src/learn/sfen_reader.h index 512f1165..206ed2bd 100644 --- a/src/learn/sfen_reader.h +++ b/src/learn/sfen_reader.h @@ -73,10 +73,10 @@ namespace Learner{ } - // Load the phase for calculation such as mse. - PSVector read_for_mse(uint64_t count) + // Read positions until `count` of them pass the eval_limit / draw-game filters; returns fewer only if reading fails. + PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games) { - PSVector sfen_for_mse; - sfen_for_mse.reserve(count); + PSVector psv; + psv.reserve(count); - for (uint64_t i = 0; i < count; ++i) + while (psv.size() < count) { @@ -84,43 +84,19 @@ namespace Learner{ if (!read_to_thread_buffer(0, ps)) { std::cout << "ERROR (sfen_reader): Reading failed." 
<< std::endl; - return sfen_for_mse; + return psv; } - sfen_for_mse.push_back(ps); + if (eval_limit < abs(ps.score)) + continue; + + if (!use_draw_games && ps.game_result == 0) + continue; + + psv.push_back(ps); } - return sfen_for_mse; - } - - PSVector read_validation_set(const std::string& file_name, int eval_limit, bool use_draw_games) - { - PSVector sfen_for_mse; - - auto input = open_sfen_input_file(file_name); - - while(!input->eof()) - { - std::optional p_opt = input->next(); - if (p_opt.has_value()) - { - auto& p = *p_opt; - - if (eval_limit < abs(p.score)) - continue; - - if (!use_draw_games && p.game_result == 0) - continue; - - sfen_for_mse.push_back(p); - } - else - { - break; - } - } - - return sfen_for_mse; + return psv; } // [ASYNC] Thread returns one aspect. Otherwise returns false.