Add 'validation_count' option for 'learn' that specifies how many positions to use for validation

This commit is contained in:
Tomasz Sobczyk
2020-12-19 23:46:18 +01:00
committed by nodchip
parent a7378f3249
commit f56613ebf6
3 changed files with 44 additions and 57 deletions

View File

@@ -80,6 +80,8 @@ Currently the following options are available:
`validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy)
`validation_count` - the number of positions to use for validation. Default: 2000.
`sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M)
`thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000

View File

@@ -482,6 +482,12 @@ namespace Learner
// Mini-batch size. Be sure to set it on the side that uses this class.
uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE;
// Number of positions used for validation calculations such as mse
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
// Since search() is performed with depth = 1 in calculation of
// move match rate, simple comparison is not possible...
uint64_t validation_count = 2000;
// Option to exclude early stage from learning
int reduction_gameply = 1;
@@ -550,16 +556,10 @@ namespace Learner
}
};
// Number of phases used for calculation such as mse
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
// Since search() is performed with depth = 1 in calculation of
// move match rate, simple comparison is not possible...
static constexpr uint64_t sfen_for_mse_size = 2000;
LearnerThink(const Params& prm) :
params(prm),
prng(prm.seed),
sr(
train_sr(
prm.filenames,
prm.shuffle,
SfenReaderMode::Cyclic,
@@ -567,6 +567,14 @@ namespace Learner
std::to_string(prng.next_random_seed()),
prm.sfen_read_size,
prm.thread_buffer_size),
validation_sr(
prm.validation_set_file_name.empty() ? prm.filenames : std::vector<std::string>{ prm.validation_set_file_name },
prm.shuffle,
SfenReaderMode::Cyclic,
1,
std::to_string(prng.next_random_seed()),
prm.sfen_read_size,
prm.thread_buffer_size),
learn_loss_sum{}
{
save_count = 0;
@@ -612,7 +620,8 @@ namespace Learner
PRNG prng;
// sfen reader
SfenReader sr;
SfenReader train_sr;
SfenReader validation_sr;
uint64_t save_count;
uint64_t loss_output_count;
@@ -666,28 +675,26 @@ namespace Learner
Eval::NNUE::verify_any_net_loaded();
const PSVector sfen_for_mse =
params.validation_set_file_name.empty()
? sr.read_for_mse(sfen_for_mse_size)
: sr.read_validation_set(
params.validation_set_file_name,
const PSVector validation_data =
validation_sr.read_some(
params.validation_count,
params.eval_limit,
params.use_draw_games_in_validation);
params.use_draw_games_in_validation
);
if (params.validation_set_file_name.empty()
&& sfen_for_mse.size() != sfen_for_mse_size)
if (validation_data.size() != params.validation_count)
{
auto out = sync_region_cout.new_region();
out
<< "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size()
<< " out of " << sfen_for_mse_size << '\n';
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << params.validation_count << '\n';
return;
}
if (params.newbob_decay != 1.0) {
calc_loss(sfen_for_mse, 0);
calc_loss(validation_data, 0);
best_loss = latest_loss_sum / latest_loss_count;
latest_loss_sum = 0.0;
@@ -714,7 +721,7 @@ namespace Learner
if (stop_flag)
break;
update_weights(sfen_for_mse, epoch);
update_weights(validation_data, epoch);
if (stop_flag)
break;
@@ -742,7 +749,7 @@ namespace Learner
RETRY_READ:;
if (!sr.read_to_thread_buffer(thread_id, ps))
if (!train_sr.read_to_thread_buffer(thread_id, ps))
{
// If we ran out of data we stop completely
// because there's nothing left to do.
@@ -1146,6 +1153,7 @@ namespace Learner
is >> filename;
params.filenames.push_back(filename);
}
else if (option == "validation_count") is >> params.validation_count;
// Specify the number of loops
else if (option == "epochs") is >> epochs;
@@ -1260,6 +1268,7 @@ namespace Learner
out << " - validation set : " << params.validation_set_file_name << endl;
}
out << " - validation count : " << params.validation_count << endl;
out << " - epochs : " << epochs << endl;
out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl;
out << " - eval_limit : " << params.eval_limit << endl;

View File

@@ -73,10 +73,10 @@ namespace Learner{
}
// Load positions for calculations such as mse.
PSVector read_for_mse(uint64_t count)
PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games)
{
PSVector sfen_for_mse;
sfen_for_mse.reserve(count);
PSVector psv;
psv.reserve(count);
for (uint64_t i = 0; i < count; ++i)
{
@@ -84,43 +84,19 @@ namespace Learner{
if (!read_to_thread_buffer(0, ps))
{
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
return sfen_for_mse;
return psv;
}
sfen_for_mse.push_back(ps);
if (eval_limit < abs(ps.score))
continue;
if (!use_draw_games && ps.game_result == 0)
continue;
psv.push_back(ps);
}
return sfen_for_mse;
}
PSVector read_validation_set(const std::string& file_name, int eval_limit, bool use_draw_games)
{
PSVector sfen_for_mse;
auto input = open_sfen_input_file(file_name);
while(!input->eof())
{
std::optional<PackedSfenValue> p_opt = input->next();
if (p_opt.has_value())
{
auto& p = *p_opt;
if (eval_limit < abs(p.score))
continue;
if (!use_draw_games && p.game_result == 0)
continue;
sfen_for_mse.push_back(p);
}
else
{
break;
}
}
return sfen_for_mse;
return psv;
}
// [ASYNC] Thread returns one aspect. Otherwise returns false.