Add random move accuracy for comparison.

This commit is contained in:
Tomasz Sobczyk
2021-03-25 14:00:00 +01:00
committed by nodchip
parent 5fdb48a7cb
commit bbe338b9fc

View File

@@ -616,7 +616,8 @@ namespace Learner
const PSVector& psv,
Loss& test_loss_sum,
atomic<double>& sum_norm,
atomic<int>& move_accord_count
atomic<int>& move_accord_count,
atomic<double>& sum_one_over_move_count
);
bool has_depth1_move_agreement(Position& pos, Move pvmove);
@@ -931,6 +932,12 @@ namespace Learner
// search matches the pv first move of search(1).
atomic<int> move_accord_count{0};
// If there is 10 legal moves then 0.1 will be added.
// This happens for each position tested.
// Effectively at the end we have the random move accuracy
// multiplied by the number of positions, which is psv.size()
atomic<double> sum_one_over_move_count{0.0};
auto mainThread = Threads.main();
mainThread->execute_with_worker([&out](auto& th){
auto& pos = th.rootPos;
@@ -949,7 +956,8 @@ namespace Learner
psv,
test_loss_sum,
sum_norm,
move_accord_count
move_accord_count,
sum_one_over_move_count
);
});
Threads.wait_for_workers_finished();
@@ -968,6 +976,7 @@ namespace Learner
out << " - norm = " << sum_norm << endl;
out << " - move accuracy = " << (move_accord_count * 100.0 / psv.size()) << "%" << endl;
out << " - random move accuracy = " << (sum_one_over_move_count * 100.0 / psv.size()) << "%" << endl;
}
else
{
@@ -983,10 +992,12 @@ namespace Learner
const PSVector& psv,
Loss& test_loss_sum,
atomic<double>& sum_norm,
atomic<int>& move_accord_count
atomic<int>& move_accord_count,
atomic<double>& sum_one_over_move_count
)
{
Loss local_loss_sum{};
double local_sum_one_over_move_count = 0.0;
auto& pos = th.rootPos;
for(;;)
@@ -1022,8 +1033,11 @@ namespace Learner
// Threat all moves with equal scores as first. This is up to move ordering.
if (has_depth1_move_agreement(pos, (Move)ps.move))
move_accord_count.fetch_add(1, std::memory_order_relaxed);
local_sum_one_over_move_count += 1.0 / static_cast<double>(MoveList<LEGAL>(pos).size());
}
sum_one_over_move_count += local_sum_one_over_move_count;
test_loss_sum += local_loss_sum;
}