Synchronize printed info regions in the learner and sfen reader.

This commit is contained in:
Tomasz Sobczyk
2020-10-24 13:34:10 +02:00
committed by nodchip
parent d824bd8ec5
commit 4b72658409
5 changed files with 66 additions and 45 deletions

View File

@@ -138,7 +138,8 @@ namespace Learner
count = 0.0;
}
void print(const std::string& prefix, ostream& s) const
template <typename StreamT>
void print(const std::string& prefix, StreamT& s) const
{
s << "==> " << prefix << "_cross_entropy_eval = " << cross_entropy_eval / count << endl;
s << "==> " << prefix << "_cross_entropy_win = " << cross_entropy_win / count << endl;
@@ -499,8 +500,9 @@ namespace Learner
if (validation_set_file_name.empty()
&& sfen_for_mse.size() != sfen_for_mse_size)
{
cout
<< "Error reading sfen_for_mse. Read " << sfen_for_mse.size()
auto out = sync_region_cout.new_region();
out
<< "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size()
<< " out of " << sfen_for_mse_size << '\n';
return;
@@ -514,7 +516,8 @@ namespace Learner
latest_loss_sum = 0.0;
latest_loss_count = 0;
cout << "initial loss: " << best_loss << endl;
auto out = sync_region_cout.new_region();
out << "INFO (learn): initial loss = " << best_loss << endl;
}
stop_flag = false;
@@ -585,7 +588,8 @@ namespace Learner
if (pos.set_from_packed_sfen(ps.sfen, &si, &th) != 0)
{
// Malformed sfen
cout << "Error! : illigal packed sfen = " << pos.fen() << endl;
auto out = sync_region_cout.new_region();
out << "ERROR: illigal packed sfen = " << pos.fen() << endl;
goto RETRY_READ;
}
@@ -674,14 +678,16 @@ namespace Learner
TT.new_search();
TimePoint elapsed = now() - Search::Limits.startTime + 1;
cout << "\n";
cout << "PROGRESS (calc_loss): " << now_string()
auto out = sync_region_cout.new_region();
out << "\n";
out << "PROGRESS (calc_loss): " << now_string()
<< ", " << total_done << " sfens"
<< ", " << total_done * 1000 / elapsed << " sfens/second"
<< ", epoch " << epoch
<< endl;
cout << "==> learning rate = " << global_learning_rate << endl;
out << "==> learning rate = " << global_learning_rate << endl;
// For calculation of verification data loss
AtomicLoss test_loss_sum{};
@@ -694,11 +700,11 @@ namespace Learner
atomic<int> move_accord_count{0};
auto mainThread = Threads.main();
mainThread->execute_with_worker([](auto& th){
mainThread->execute_with_worker([&out](auto& th){
auto& pos = th.rootPos;
StateInfo si;
pos.set(StartFEN, false, &si, &th);
cout << "==> startpos eval = " << Eval::evaluate(pos) << endl;
out << "==> startpos eval = " << Eval::evaluate(pos) << endl;
});
mainThread->wait_for_worker_finished();
@@ -721,19 +727,19 @@ namespace Learner
if (psv.size() && test_loss_sum.count > 0.0)
{
test_loss_sum.print("test", cout);
test_loss_sum.print("test", out);
if (learn_loss_sum.count > 0.0)
{
learn_loss_sum.print("learn", cout);
learn_loss_sum.print("learn", out);
}
cout << "==> norm = " << sum_norm << endl;
cout << "==> move accuracy = " << (move_accord_count * 100.0 / psv.size()) << "%" << endl;
out << "==> norm = " << sum_norm << endl;
out << "==> move accuracy = " << (move_accord_count * 100.0 / psv.size()) << "%" << endl;
}
else
{
cout << "Error! : psv.size() = " << psv.size() << " , done = " << test_loss_sum.count << endl;
out << "ERROR: psv.size() = " << psv.size() << " , done = " << test_loss_sum.count << endl;
}
learn_loss_sum.reset();