Simple filtering for validation data.

This commit is contained in:
Tomasz Sobczyk
2020-12-21 20:57:51 +01:00
committed by nodchip
parent 50df3a7389
commit 6853b4aac2
3 changed files with 61 additions and 16 deletions

View File

@@ -521,6 +521,7 @@ namespace Learner
bool assume_quiet = false;
bool smart_fen_skipping = false;
bool smart_fen_skipping_for_validation = false;
double learning_rate = 1.0;
double max_grad = 1.0;
@@ -593,6 +594,8 @@ namespace Learner
private:
static void set_learning_search_limits();
PSVector fetch_next_validation_set();
void learn_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);
void update_weights(const PSVector& psv, uint64_t epoch);
@@ -665,6 +668,44 @@ namespace Learner
limits.depth = 0;
}
PSVector LearnerThink::fetch_next_validation_set()
{
PSVector validation_data;
auto mainThread = Threads.main();
mainThread->execute_with_worker([&validation_data, this](auto& th){
auto do_include_predicate = [&th, this](const PackedSfenValue& ps) -> bool {
if (params.eval_limit < abs(ps.score))
return false;
if (!params.use_draw_games_in_validation && ps.game_result == 0)
return false;
if (params.smart_fen_skipping_for_validation)
{
StateInfo si;
auto& pos = th.rootPos;
if (pos.set_from_packed_sfen(ps.sfen, &si, &th) != 0)
return false;
if (pos.capture_or_promotion((Move)ps.move) || pos.checkers())
return false;
}
return true;
};
validation_data = validation_sr.read_some(
params.validation_count,
params.validation_count * 100, // to have a reasonable bound on the running time.
do_include_predicate
);
});
mainThread->wait_for_worker_finished();
return validation_data;
}
void LearnerThink::learn(uint64_t epochs)
{
#if defined(_OPENMP)
@@ -675,19 +716,16 @@ namespace Learner
Eval::NNUE::verify_any_net_loaded();
const PSVector validation_data =
validation_sr.read_some(
params.validation_count,
params.eval_limit,
params.use_draw_games_in_validation
);
const PSVector validation_data = fetch_next_validation_set();
if (validation_data.size() != params.validation_count)
{
auto out = sync_region_cout.new_region();
out
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << params.validation_count << '\n';
<< " out of " << params.validation_count << '\n'
<< "INFO (learn): This either means that less than 1% of the validation data passed the filter"
<< " or the file is empty\n";
return;
}
@@ -1235,6 +1273,7 @@ namespace Learner
else if (option == "verbose") params.verbose = true;
else if (option == "assume_quiet") params.assume_quiet = true;
else if (option == "smart_fen_skipping") params.smart_fen_skipping = true;
else if (option == "smart_fen_skipping_for_validation") params.smart_fen_skipping_for_validation = true;
else
{
out << "INFO: Unknown option: " << option << ". Ignoring.\n";
@@ -1306,6 +1345,9 @@ namespace Learner
out << " - sfen_read_size : " << params.sfen_read_size << endl;
out << " - thread_buffer_size : " << params.thread_buffer_size << endl;
out << " - smart_fen_skipping : " << params.smart_fen_skipping << endl;
out << " - smart_fen_skipping_val : " << params.smart_fen_skipping_for_validation << endl;
out << " - seed : " << params.seed << endl;
out << " - verbose : " << (params.verbose ? "true" : "false") << endl;