From bbe338b9fcab3ee7f071303ca35956ad667cc6b2 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Thu, 25 Mar 2021 14:00:00 +0100 Subject: [PATCH] Add random move accuracy for comparison. --- src/learn/learn.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index e17537ff..cf19bcc2 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -616,7 +616,8 @@ namespace Learner const PSVector& psv, Loss& test_loss_sum, atomic& sum_norm, - atomic& move_accord_count + atomic& move_accord_count, + atomic& sum_one_over_move_count ); bool has_depth1_move_agreement(Position& pos, Move pvmove); @@ -931,6 +932,12 @@ namespace Learner // search matches the pv first move of search(1). atomic move_accord_count{0}; + // If there is 10 legal moves then 0.1 will be added. + // This happens for each position tested. + // Effectively at the end we have the random move accuracy + // multiplied by the number of positions, which is psv.size() + atomic sum_one_over_move_count{0.0}; + auto mainThread = Threads.main(); mainThread->execute_with_worker([&out](auto& th){ auto& pos = th.rootPos; @@ -949,7 +956,8 @@ namespace Learner psv, test_loss_sum, sum_norm, - move_accord_count + move_accord_count, + sum_one_over_move_count ); }); Threads.wait_for_workers_finished(); @@ -968,6 +976,7 @@ namespace Learner out << " - norm = " << sum_norm << endl; out << " - move accuracy = " << (move_accord_count * 100.0 / psv.size()) << "%" << endl; + out << " - random move accuracy = " << (sum_one_over_move_count * 100.0 / psv.size()) << "%" << endl; } else { @@ -983,10 +992,12 @@ namespace Learner const PSVector& psv, Loss& test_loss_sum, atomic& sum_norm, - atomic& move_accord_count + atomic& move_accord_count, + atomic& sum_one_over_move_count ) { Loss local_loss_sum{}; + double local_sum_one_over_move_count = 0.0; auto& pos = th.rootPos; for(;;) @@ -1022,8 +1033,11 @@ namespace Learner // Threat all moves with equal scores as first. This is up to move ordering. if (has_depth1_move_agreement(pos, (Move)ps.move)) move_accord_count.fetch_add(1, std::memory_order_relaxed); + + local_sum_one_over_move_count += 1.0 / static_cast(MoveList(pos).size()); } + sum_one_over_move_count += local_sum_one_over_move_count; test_loss_sum += local_loss_sum; }