pass shuffle flag in the constructor

This commit is contained in:
Tomasz Sobczyk
2020-10-23 22:19:50 +02:00
committed by nodchip
parent 31f94a18b3
commit 8fb208598b
2 changed files with 5 additions and 14 deletions

View File

@@ -383,11 +383,12 @@ namespace Learner
LearnerThink( LearnerThink(
const std::vector<std::string>& filenames, const std::vector<std::string>& filenames,
bool shuffle,
uint64_t thread_num, uint64_t thread_num,
const std::string& seed const std::string& seed
) : ) :
prng(seed), prng(seed),
sr(filenames, SfenReaderMode::Cyclic, thread_num, std::to_string(prng.next_random_seed())), sr(filenames, shuffle, SfenReaderMode::Cyclic, thread_num, std::to_string(prng.next_random_seed())),
learn_loss_sum{} learn_loss_sum{}
{ {
save_only_once = false; save_only_once = false;
@@ -403,11 +404,6 @@ namespace Learner
total_done = 0; total_done = 0;
} }
void set_do_shuffle(bool v)
{
sr.set_do_shuffle(v);
}
void learn(uint64_t epochs); void learn(uint64_t epochs);
@@ -1150,7 +1146,7 @@ namespace Learner
Eval::NNUE::set_batch_size(nn_batch_size); Eval::NNUE::set_batch_size(nn_batch_size);
Eval::NNUE::set_options(nn_options); Eval::NNUE::set_options(nn_options);
LearnerThink learn_think(filenames, thread_num, seed); LearnerThink learn_think(filenames, !no_shuffle, thread_num, seed);
if (newbob_decay != 1.0 && !Options["SkipLoadingEval"]) { if (newbob_decay != 1.0 && !Options["SkipLoadingEval"]) {
// Save the current net to [EvalSaveDir]\original. // Save the current net to [EvalSaveDir]\original.
@@ -1165,7 +1161,6 @@ namespace Learner
// Reflect other option settings. // Reflect other option settings.
learn_think.eval_limit = eval_limit; learn_think.eval_limit = eval_limit;
learn_think.save_only_once = save_only_once; learn_think.save_only_once = save_only_once;
learn_think.set_do_shuffle(!no_shuffle);
learn_think.reduction_gameply = reduction_gameply; learn_think.reduction_gameply = reduction_gameply;
learn_think.newbob_decay = newbob_decay; learn_think.newbob_decay = newbob_decay;

View File

@@ -40,6 +40,7 @@ namespace Learner{
// Because it always the same integers on MinGW. // Because it always the same integers on MinGW.
SfenReader( SfenReader(
const std::vector<std::string>& filenames_, const std::vector<std::string>& filenames_,
bool do_shuffle,
SfenReaderMode mode_, SfenReaderMode mode_,
int thread_num, int thread_num,
const std::string& seed const std::string& seed
@@ -51,7 +52,7 @@ namespace Learner{
packed_sfens.resize(thread_num); packed_sfens.resize(thread_num);
total_read = 0; total_read = 0;
end_of_files = false; end_of_files = false;
shuffle = true; shuffle = do_shuffle;
stop_flag = false; stop_flag = false;
file_worker_thread = std::thread([&] { file_worker_thread = std::thread([&] {
@@ -312,11 +313,6 @@ namespace Learner{
} }
} }
void set_do_shuffle(bool v)
{
shuffle = v;
}
protected: protected:
// worker thread reading file in background // worker thread reading file in background