From 6853b4aac2055321944c07aa0aa354f5d5803a17 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 21 Dec 2020 20:57:51 +0100 Subject: [PATCH] Simple filtering for validation data. --- docs/learn.md | 4 ++- src/learn/learn.cpp | 56 +++++++++++++++++++++++++++++++++++------ src/learn/sfen_reader.h | 17 +++++++------ 3 files changed, 61 insertions(+), 16 deletions(-) diff --git a/docs/learn.md b/docs/learn.md index e88de089..fe88e7e8 100644 --- a/docs/learn.md +++ b/docs/learn.md @@ -66,7 +66,9 @@ Currently the following options are available: `assume_quiet` - this is a flag option. When specified learn will not perform qsearch to reach a quiet position. -`smart_fen_skipping` - this is a flag option. When specified some position that are not good candidates for teaching are skipped. This includes positions where the best move is a capture or promotion, and position where a king is in check. +`smart_fen_skipping` - this is a flag option. When specified, some positions that are not good candidates for teaching are skipped. This includes positions where the best move is a capture or promotion, and positions where a king is in check. Default: 1. + +`smart_fen_skipping_for_validation` - same as `smart_fen_skipping` but applies to the validation data set. Default: 0. `newbob_num_trials` - determines after how many subsequent rejected nets the training process will be terminated. Default: 4. 
diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 90f629e1..c3499283 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -521,6 +521,7 @@ namespace Learner bool assume_quiet = false; bool smart_fen_skipping = false; + bool smart_fen_skipping_for_validation = false; double learning_rate = 1.0; double max_grad = 1.0; @@ -593,6 +594,8 @@ namespace Learner private: static void set_learning_search_limits(); + PSVector fetch_next_validation_set(); + void learn_worker(Thread& th, std::atomic& counter, uint64_t limit); void update_weights(const PSVector& psv, uint64_t epoch); @@ -665,6 +668,44 @@ namespace Learner limits.depth = 0; } + PSVector LearnerThink::fetch_next_validation_set() + { + PSVector validation_data; + + auto mainThread = Threads.main(); + mainThread->execute_with_worker([&validation_data, this](auto& th){ + auto do_include_predicate = [&th, this](const PackedSfenValue& ps) -> bool { + if (params.eval_limit < abs(ps.score)) + return false; + + if (!params.use_draw_games_in_validation && ps.game_result == 0) + return false; + + if (params.smart_fen_skipping_for_validation) + { + StateInfo si; + auto& pos = th.rootPos; + if (pos.set_from_packed_sfen(ps.sfen, &si, &th) != 0) + return false; + + if (pos.capture_or_promotion((Move)ps.move) || pos.checkers()) + return false; + } + + return true; + }; + + validation_data = validation_sr.read_some( + params.validation_count, + params.validation_count * 100, // to have a reasonable bound on the running time. 
+ do_include_predicate + ); + }); + mainThread->wait_for_worker_finished(); + + return validation_data; + } + void LearnerThink::learn(uint64_t epochs) { #if defined(_OPENMP) @@ -675,19 +716,16 @@ namespace Learner Eval::NNUE::verify_any_net_loaded(); - const PSVector validation_data = - validation_sr.read_some( - params.validation_count, - params.eval_limit, - params.use_draw_games_in_validation - ); + const PSVector validation_data = fetch_next_validation_set(); if (validation_data.size() != params.validation_count) { auto out = sync_region_cout.new_region(); out << "INFO (learn): Error reading validation data. Read " << validation_data.size() - << " out of " << params.validation_count << '\n'; + << " out of " << params.validation_count << '\n' + << "INFO (learn): This either means that less than 1% of the validation data passed the filter" + << " or the file is empty\n"; return; } @@ -1235,6 +1273,7 @@ namespace Learner else if (option == "verbose") params.verbose = true; else if (option == "assume_quiet") params.assume_quiet = true; else if (option == "smart_fen_skipping") params.smart_fen_skipping = true; + else if (option == "smart_fen_skipping_for_validation") params.smart_fen_skipping_for_validation = true; else { out << "INFO: Unknown option: " << option << ". Ignoring.\n"; @@ -1306,6 +1345,9 @@ namespace Learner out << " - sfen_read_size : " << params.sfen_read_size << endl; out << " - thread_buffer_size : " << params.thread_buffer_size << endl; + out << " - smart_fen_skipping : " << params.smart_fen_skipping << endl; + out << " - smart_fen_skipping_val : " << params.smart_fen_skipping_for_validation << endl; + out << " - seed : " << params.seed << endl; out << " - verbose : " << (params.verbose ? 
"true" : "false") << endl; diff --git a/src/learn/sfen_reader.h b/src/learn/sfen_reader.h index 206ed2bd..1574f63a 100644 --- a/src/learn/sfen_reader.h +++ b/src/learn/sfen_reader.h @@ -15,6 +15,7 @@ #include #include #include +#include namespace Learner{ @@ -73,12 +74,12 @@ namespace Learner{ } // Load the phase for calculation such as mse. - PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games) + PSVector read_some(uint64_t count, uint64_t count_tries, std::function do_take) { PSVector psv; psv.reserve(count); - for (uint64_t i = 0; i < count; ++i) + for (uint64_t i = 0; i < count_tries; ++i) { PackedSfenValue ps; if (!read_to_thread_buffer(0, ps)) @@ -87,13 +88,13 @@ namespace Learner{ return psv; } - if (eval_limit < abs(ps.score)) - continue; + if (do_take(ps)) + { + psv.push_back(ps); - if (!use_draw_games && ps.game_result == 0) - continue; - - psv.push_back(ps); + if (psv.size() >= count) + break; + } } return psv;