mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 11:36:51 +08:00
Simple filtering for validation data.
This commit is contained in:
@@ -521,6 +521,7 @@ namespace Learner
|
||||
|
||||
bool assume_quiet = false;
|
||||
bool smart_fen_skipping = false;
|
||||
bool smart_fen_skipping_for_validation = false;
|
||||
|
||||
double learning_rate = 1.0;
|
||||
double max_grad = 1.0;
|
||||
@@ -593,6 +594,8 @@ namespace Learner
|
||||
private:
|
||||
static void set_learning_search_limits();
|
||||
|
||||
PSVector fetch_next_validation_set();
|
||||
|
||||
void learn_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);
|
||||
|
||||
void update_weights(const PSVector& psv, uint64_t epoch);
|
||||
@@ -665,6 +668,44 @@ namespace Learner
|
||||
limits.depth = 0;
|
||||
}
|
||||
|
||||
PSVector LearnerThink::fetch_next_validation_set()
|
||||
{
|
||||
PSVector validation_data;
|
||||
|
||||
auto mainThread = Threads.main();
|
||||
mainThread->execute_with_worker([&validation_data, this](auto& th){
|
||||
auto do_include_predicate = [&th, this](const PackedSfenValue& ps) -> bool {
|
||||
if (params.eval_limit < abs(ps.score))
|
||||
return false;
|
||||
|
||||
if (!params.use_draw_games_in_validation && ps.game_result == 0)
|
||||
return false;
|
||||
|
||||
if (params.smart_fen_skipping_for_validation)
|
||||
{
|
||||
StateInfo si;
|
||||
auto& pos = th.rootPos;
|
||||
if (pos.set_from_packed_sfen(ps.sfen, &si, &th) != 0)
|
||||
return false;
|
||||
|
||||
if (pos.capture_or_promotion((Move)ps.move) || pos.checkers())
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
validation_data = validation_sr.read_some(
|
||||
params.validation_count,
|
||||
params.validation_count * 100, // to have a reasonable bound on the running time.
|
||||
do_include_predicate
|
||||
);
|
||||
});
|
||||
mainThread->wait_for_worker_finished();
|
||||
|
||||
return validation_data;
|
||||
}
|
||||
|
||||
void LearnerThink::learn(uint64_t epochs)
|
||||
{
|
||||
#if defined(_OPENMP)
|
||||
@@ -675,19 +716,16 @@ namespace Learner
|
||||
|
||||
Eval::NNUE::verify_any_net_loaded();
|
||||
|
||||
const PSVector validation_data =
|
||||
validation_sr.read_some(
|
||||
params.validation_count,
|
||||
params.eval_limit,
|
||||
params.use_draw_games_in_validation
|
||||
);
|
||||
const PSVector validation_data = fetch_next_validation_set();
|
||||
|
||||
if (validation_data.size() != params.validation_count)
|
||||
{
|
||||
auto out = sync_region_cout.new_region();
|
||||
out
|
||||
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
|
||||
<< " out of " << params.validation_count << '\n';
|
||||
<< " out of " << params.validation_count << '\n'
|
||||
<< "INFO (learn): This either means that less than 1% of the validation data passed the filter"
|
||||
<< " or the file is empty\n";
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -1235,6 +1273,7 @@ namespace Learner
|
||||
else if (option == "verbose") params.verbose = true;
|
||||
else if (option == "assume_quiet") params.assume_quiet = true;
|
||||
else if (option == "smart_fen_skipping") params.smart_fen_skipping = true;
|
||||
else if (option == "smart_fen_skipping_for_validation") params.smart_fen_skipping_for_validation = true;
|
||||
else
|
||||
{
|
||||
out << "INFO: Unknown option: " << option << ". Ignoring.\n";
|
||||
@@ -1306,6 +1345,9 @@ namespace Learner
|
||||
out << " - sfen_read_size : " << params.sfen_read_size << endl;
|
||||
out << " - thread_buffer_size : " << params.thread_buffer_size << endl;
|
||||
|
||||
out << " - smart_fen_skipping : " << params.smart_fen_skipping << endl;
|
||||
out << " - smart_fen_skipping_val : " << params.smart_fen_skipping_for_validation << endl;
|
||||
|
||||
out << " - seed : " << params.seed << endl;
|
||||
out << " - verbose : " << (params.verbose ? "true" : "false") << endl;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user