mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
Add 'validation_count' option for 'learn' that specifies how many positions to use for validation
This commit is contained in:
@@ -80,6 +80,8 @@ Currently the following options are available:
|
||||
|
||||
`validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy)
|
||||
|
||||
`validation_count` - the number of positions to use for validation. Default: 2000.
|
||||
|
||||
`sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M)
|
||||
|
||||
`thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000
|
||||
|
||||
@@ -482,6 +482,12 @@ namespace Learner
|
||||
// Mini batch size size. Be sure to set it on the side that uses this class.
|
||||
uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE;
|
||||
|
||||
// Number of phases used for calculation such as mse
|
||||
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
|
||||
// Since search() is performed with depth = 1 in calculation of
|
||||
// move match rate, simple comparison is not possible...
|
||||
uint64_t validation_count = 2000;
|
||||
|
||||
// Option to exclude early stage from learning
|
||||
int reduction_gameply = 1;
|
||||
|
||||
@@ -550,16 +556,10 @@ namespace Learner
|
||||
}
|
||||
};
|
||||
|
||||
// Number of phases used for calculation such as mse
|
||||
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
|
||||
// Since search() is performed with depth = 1 in calculation of
|
||||
// move match rate, simple comparison is not possible...
|
||||
static constexpr uint64_t sfen_for_mse_size = 2000;
|
||||
|
||||
LearnerThink(const Params& prm) :
|
||||
params(prm),
|
||||
prng(prm.seed),
|
||||
sr(
|
||||
train_sr(
|
||||
prm.filenames,
|
||||
prm.shuffle,
|
||||
SfenReaderMode::Cyclic,
|
||||
@@ -567,6 +567,14 @@ namespace Learner
|
||||
std::to_string(prng.next_random_seed()),
|
||||
prm.sfen_read_size,
|
||||
prm.thread_buffer_size),
|
||||
validation_sr(
|
||||
prm.validation_set_file_name.empty() ? prm.filenames : std::vector<std::string>{ prm.validation_set_file_name },
|
||||
prm.shuffle,
|
||||
SfenReaderMode::Cyclic,
|
||||
1,
|
||||
std::to_string(prng.next_random_seed()),
|
||||
prm.sfen_read_size,
|
||||
prm.thread_buffer_size),
|
||||
learn_loss_sum{}
|
||||
{
|
||||
save_count = 0;
|
||||
@@ -612,7 +620,8 @@ namespace Learner
|
||||
PRNG prng;
|
||||
|
||||
// sfen reader
|
||||
SfenReader sr;
|
||||
SfenReader train_sr;
|
||||
SfenReader validation_sr;
|
||||
|
||||
uint64_t save_count;
|
||||
uint64_t loss_output_count;
|
||||
@@ -666,28 +675,26 @@ namespace Learner
|
||||
|
||||
Eval::NNUE::verify_any_net_loaded();
|
||||
|
||||
const PSVector sfen_for_mse =
|
||||
params.validation_set_file_name.empty()
|
||||
? sr.read_for_mse(sfen_for_mse_size)
|
||||
: sr.read_validation_set(
|
||||
params.validation_set_file_name,
|
||||
const PSVector validation_data =
|
||||
validation_sr.read_some(
|
||||
params.validation_count,
|
||||
params.eval_limit,
|
||||
params.use_draw_games_in_validation);
|
||||
params.use_draw_games_in_validation
|
||||
);
|
||||
|
||||
if (params.validation_set_file_name.empty()
|
||||
&& sfen_for_mse.size() != sfen_for_mse_size)
|
||||
if (validation_data.size() != params.validation_count)
|
||||
{
|
||||
auto out = sync_region_cout.new_region();
|
||||
out
|
||||
<< "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size()
|
||||
<< " out of " << sfen_for_mse_size << '\n';
|
||||
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
|
||||
<< " out of " << params.validation_count << '\n';
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (params.newbob_decay != 1.0) {
|
||||
|
||||
calc_loss(sfen_for_mse, 0);
|
||||
calc_loss(validation_data, 0);
|
||||
|
||||
best_loss = latest_loss_sum / latest_loss_count;
|
||||
latest_loss_sum = 0.0;
|
||||
@@ -714,7 +721,7 @@ namespace Learner
|
||||
if (stop_flag)
|
||||
break;
|
||||
|
||||
update_weights(sfen_for_mse, epoch);
|
||||
update_weights(validation_data, epoch);
|
||||
|
||||
if (stop_flag)
|
||||
break;
|
||||
@@ -742,7 +749,7 @@ namespace Learner
|
||||
|
||||
RETRY_READ:;
|
||||
|
||||
if (!sr.read_to_thread_buffer(thread_id, ps))
|
||||
if (!train_sr.read_to_thread_buffer(thread_id, ps))
|
||||
{
|
||||
// If we ran out of data we stop completely
|
||||
// because there's nothing left to do.
|
||||
@@ -1146,6 +1153,7 @@ namespace Learner
|
||||
is >> filename;
|
||||
params.filenames.push_back(filename);
|
||||
}
|
||||
else if (option == "validation_count") is >> params.validation_count;
|
||||
|
||||
// Specify the number of loops
|
||||
else if (option == "epochs") is >> epochs;
|
||||
@@ -1260,6 +1268,7 @@ namespace Learner
|
||||
out << " - validation set : " << params.validation_set_file_name << endl;
|
||||
}
|
||||
|
||||
out << " - validation count : " << params.validation_count << endl;
|
||||
out << " - epochs : " << epochs << endl;
|
||||
out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl;
|
||||
out << " - eval_limit : " << params.eval_limit << endl;
|
||||
|
||||
@@ -73,10 +73,10 @@ namespace Learner{
|
||||
}
|
||||
|
||||
// Load the phase for calculation such as mse.
|
||||
PSVector read_for_mse(uint64_t count)
|
||||
PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games)
|
||||
{
|
||||
PSVector sfen_for_mse;
|
||||
sfen_for_mse.reserve(count);
|
||||
PSVector psv;
|
||||
psv.reserve(count);
|
||||
|
||||
for (uint64_t i = 0; i < count; ++i)
|
||||
{
|
||||
@@ -84,43 +84,19 @@ namespace Learner{
|
||||
if (!read_to_thread_buffer(0, ps))
|
||||
{
|
||||
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
|
||||
return sfen_for_mse;
|
||||
return psv;
|
||||
}
|
||||
|
||||
sfen_for_mse.push_back(ps);
|
||||
if (eval_limit < abs(ps.score))
|
||||
continue;
|
||||
|
||||
if (!use_draw_games && ps.game_result == 0)
|
||||
continue;
|
||||
|
||||
psv.push_back(ps);
|
||||
}
|
||||
|
||||
return sfen_for_mse;
|
||||
}
|
||||
|
||||
PSVector read_validation_set(const std::string& file_name, int eval_limit, bool use_draw_games)
|
||||
{
|
||||
PSVector sfen_for_mse;
|
||||
|
||||
auto input = open_sfen_input_file(file_name);
|
||||
|
||||
while(!input->eof())
|
||||
{
|
||||
std::optional<PackedSfenValue> p_opt = input->next();
|
||||
if (p_opt.has_value())
|
||||
{
|
||||
auto& p = *p_opt;
|
||||
|
||||
if (eval_limit < abs(p.score))
|
||||
continue;
|
||||
|
||||
if (!use_draw_games && p.game_result == 0)
|
||||
continue;
|
||||
|
||||
sfen_for_mse.push_back(p);
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return sfen_for_mse;
|
||||
return psv;
|
||||
}
|
||||
|
||||
// [ASYNC] Thread returns one aspect. Otherwise returns false.
|
||||
|
||||
Reference in New Issue
Block a user