diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp
index 450a80c6..449542a7 100644
--- a/src/learn/learn.cpp
+++ b/src/learn/learn.cpp
@@ -310,7 +310,8 @@ namespace Learner
         return perf;
     }
 
-    [[maybe_unused]] static ValueWithGrad<double> get_loss_noob(Value shallow, Value teacher_signal, int result, int /* ply */)
+    [[maybe_unused]] static ValueWithGrad<double> get_loss_noob(
+        Value shallow, Value teacher_signal, int result, int /* ply */)
     {
         using namespace Learner::Autograd::UnivariateStatic;
 
@@ -328,7 +329,7 @@ namespace Learner
         return loss_.eval(args);
     }
 
-    static ValueWithGrad<double> get_loss_cross_entropy(Value shallow, Value teacher_signal, int result, int /* ply */)
+    static auto& get_loss_cross_entropy_()
     {
         using namespace Learner::Autograd::UnivariateStatic;
 
@@ -338,18 +339,45 @@ namespace Learner
         static thread_local auto lambda_ = ConstantParameter{};
         static thread_local auto& loss_ = cross_entropy_(q_, p_, t_, lambda_);
 
-        auto args = std::tuple(
+        return loss_;
+    }
+
+    static auto get_loss_cross_entropy_args(
+        Value shallow, Value teacher_signal, int result)
+    {
+        return std::tuple(
             (double)shallow,
             (double)teacher_signal,
             (double)result,
             calculate_lambda(teacher_signal)
         );
+    }
+
+    static ValueWithGrad<double> get_loss_cross_entropy(
+        Value shallow, Value teacher_signal, int result, int /* ply */)
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        static thread_local auto& loss_ = get_loss_cross_entropy_();
+
+        auto args = get_loss_cross_entropy_args(shallow, teacher_signal, result);
 
         return loss_.eval(args);
     }
 
-    static ValueWithGrad<double> get_loss_cross_entropy_use_wdl(
-        Value shallow, Value teacher_signal, int result, int ply)
+    static ValueWithGrad<double> get_loss_cross_entropy_no_grad(
+        Value shallow, Value teacher_signal, int result, int /* ply */)
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        static thread_local auto& loss_ = get_loss_cross_entropy_();
+
+        auto args = get_loss_cross_entropy_args(shallow, teacher_signal, result);
+
+        return { loss_.value(args), 0.0 };
+    }
+
+    static auto& get_loss_cross_entropy_use_wdl_()
     {
         using namespace Learner::Autograd::UnivariateStatic;
 
@@ -364,7 +392,13 @@ namespace Learner
         static thread_local auto lambda_ = ConstantParameter{};
         static thread_local auto& loss_ = cross_entropy_(q_, p_, t_, lambda_);
 
-        auto args = std::tuple(
+        return loss_;
+    }
+
+    static auto get_loss_cross_entropy_use_wdl_args(
+        Value shallow, Value teacher_signal, int result, int ply)
+    {
+        return std::tuple(
             (double)shallow,
             // This is required because otherwise MSVC crashes :(
             expected_perf_use_wdl(scale_score(teacher_signal), ply),
@@ -372,10 +406,32 @@ namespace Learner
             calculate_lambda(teacher_signal),
             (double)std::min(240, ply)
         );
+    }
+
+    static ValueWithGrad<double> get_loss_cross_entropy_use_wdl(
+        Value shallow, Value teacher_signal, int result, int ply)
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        static thread_local auto& loss_ = get_loss_cross_entropy_use_wdl_();
+
+        auto args = get_loss_cross_entropy_use_wdl_args(shallow, teacher_signal, result, ply);
 
         return loss_.eval(args);
     }
 
+    static ValueWithGrad<double> get_loss_cross_entropy_use_wdl_no_grad(
+        Value shallow, Value teacher_signal, int result, int ply)
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        static thread_local auto& loss_ = get_loss_cross_entropy_use_wdl_();
+
+        auto args = get_loss_cross_entropy_use_wdl_args(shallow, teacher_signal, result, ply);
+
+        return { loss_.value(args), 0.0 };
+    }
+
     static auto get_loss(Value shallow, Value teacher_signal, int result, int ply)
     {
         using namespace Learner::Autograd::UnivariateStatic;
@@ -390,7 +446,21 @@ namespace Learner
         }
     }
 
-    static auto get_loss(
+    static auto get_loss_no_grad(Value shallow, Value teacher_signal, int result, int ply)
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        if (use_wdl)
+        {
+            return get_loss_cross_entropy_use_wdl_no_grad(shallow, teacher_signal, result, ply);
+        }
+        else
+        {
+            return get_loss_cross_entropy_no_grad(shallow, teacher_signal, result, ply);
+        }
+    }
+
+    [[maybe_unused]] static auto get_loss(
         Value teacher_signal,
         Value shallow,
         const PackedSfenValue& psv)
@@ -398,6 +468,14 @@ namespace Learner
         return get_loss(shallow, teacher_signal, psv.game_result, psv.gamePly);
     }
 
+    static auto get_loss_no_grad(
+        Value teacher_signal,
+        Value shallow,
+        const PackedSfenValue& psv)
+    {
+        return get_loss_no_grad(shallow, teacher_signal, psv.game_result, psv.gamePly);
+    }
+
     // Class to generate sfen with multiple threads
     struct LearnerThink
     {
@@ -828,11 +906,11 @@ namespace Learner
 
         if (psv.size() && test_loss_sum.count() > 0)
         {
-            test_loss_sum.print("val", out);
+            test_loss_sum.print_only_loss("val", out);
 
             if (learn_loss_sum.count() > 0)
             {
-                learn_loss_sum.print("train", out);
+                learn_loss_sum.print_with_grad("train", out);
             }
 
             out << " - norm = " << sum_norm << endl;
@@ -880,7 +958,7 @@ namespace Learner
             // Evaluation value of deep search
             const auto deep_value = (Value)ps.score;
 
-            const auto loss = get_loss(
+            const auto loss = get_loss_no_grad(
                 deep_value,
                 shallow_value,
                 ps);
diff --git a/src/learn/learn.h b/src/learn/learn.h
index 552096b2..842ffad0 100644
--- a/src/learn/learn.h
+++ b/src/learn/learn.h
@@ -126,12 +126,18 @@ namespace Learner
         }
 
         template <typename StreamT>
-        void print(const std::string& prefix, StreamT& s) const
+        void print_with_grad(const std::string& prefix, StreamT& s) const
         {
             s << " - " << prefix << "_loss = " << m_loss.value / (double)m_count << std::endl;
             s << " - " << prefix << "_grad_norm = " << m_loss.grad / (double)m_count << std::endl;
         }
 
+        template <typename StreamT>
+        void print_only_loss(const std::string& prefix, StreamT& s) const
+        {
+            s << " - " << prefix << "_loss = " << m_loss.value / (double)m_count << std::endl;
+        }
+
     private:
         ValueWithGrad<double> m_loss{ 0.0, 0.0 };
         uint64_t m_count{0};
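
Reviewer note: the patch factors each loss into three pieces: a builder (get_loss_cross_entropy_()) that constructs the thread_local autograd graph once, an argument packer (get_loss_cross_entropy_args()), and two evaluators: eval(args), which returns the loss value together with its gradient for the training path, and value(args), which computes only the loss and is wrapped as { value, 0.0 } for the validation path. Below is a minimal self-contained sketch of that pattern, using a hand-rolled sigmoid cross-entropy in place of the repository's Autograd machinery; Expr, its members, and main() are illustrative assumptions, not the codebase's API.

    #include <cmath>
    #include <iostream>
    #include <tuple>

    struct ValueWithGrad {
        double value;
        double grad;
    };

    // Stand-in (illustrative, not the repo's autograd): a sigmoid cross-entropy
    // loss(q, t) with its derivative w.r.t. the shallow evaluation q.
    struct Expr {
        // Full evaluation: loss value plus gradient (training path).
        ValueWithGrad eval(const std::tuple<double, double>& args) const {
            auto [q, t] = args;
            double p = 1.0 / (1.0 + std::exp(-q));                    // sigmoid(q)
            double loss = -(t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
            return { loss, p - t };                                   // d(loss)/dq = p - t
        }

        // Value-only evaluation: skips all gradient work (validation path).
        double value(const std::tuple<double, double>& args) const {
            auto [q, t] = args;
            double p = 1.0 / (1.0 + std::exp(-q));
            return -(t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
        }
    };

    // Mirrors get_loss(): the graph is built once per thread, then evaluated.
    ValueWithGrad get_loss(double shallow, double target) {
        static thread_local Expr loss_;
        return loss_.eval({ shallow, target });
    }

    // Mirrors get_loss_no_grad(): same graph, value-only, gradient slot zeroed
    // so the existing ValueWithGrad plumbing needs no changes.
    ValueWithGrad get_loss_no_grad(double shallow, double target) {
        static thread_local Expr loss_;
        return { loss_.value({ shallow, target }), 0.0 };
    }

    int main() {
        auto train = get_loss(0.3, 1.0);
        auto val   = get_loss_no_grad(0.3, 1.0);
        std::cout << "train: loss=" << train.value << " grad=" << train.grad << '\n'
                  << "val:   loss=" << val.value   << " grad=" << val.grad   << '\n';
    }

The zeroed gradient also explains the learn.h split: a validation accumulator only ever sums 0.0 gradients, so print_only_loss avoids reporting a meaningless val_grad_norm, while the training accumulator keeps the full print_with_grad output.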