From e975889132bd2303915a2e2eb587b2633487c358 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 30 Nov 2020 15:21:39 +0100 Subject: [PATCH] Move cross_entropy calculation to a separate function. --- src/learn/learn.cpp | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index af867d42..dd893d9d 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -197,6 +197,29 @@ namespace Learner return lambda; } + template <typename ShallowT, typename TeacherT, typename ResultT, typename LambdaT> + static auto& cross_entropy_( + ShallowT& q_, + TeacherT& p_, + ResultT& t_, + LambdaT& lambda_ + ) + { + using namespace Learner::Autograd::UnivariateStatic; + + constexpr double epsilon = 1e-12; + + static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon)); + static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon)); + static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_)); + static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_)); + static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_; + static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_; + static thread_local auto loss_ = result_ - entropy_; + + return loss_; + } + static ValueWithGrad get_loss(Value shallow, Value teacher_signal, int result, int ply) { using namespace Learner::Autograd::UnivariateStatic; @@ -215,19 +238,11 @@ namespace Learner auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0)); */ - constexpr double epsilon = 1e-12; - static thread_local auto q_ = sigmoid(VariableParameter{} * ConstantParameter{}); static thread_local auto p_ = sigmoid(ConstantParameter{} * ConstantParameter{}); static thread_local auto t_ = (ConstantParameter{} + 1.0) * 0.5; static thread_local auto lambda_ = ConstantParameter{}; - static thread_local
auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon)); - static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon)); - static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_)); - static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_)); - static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_; - static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_; - static thread_local auto loss_ = result_ - entropy_; + static thread_local auto loss_ = cross_entropy_(q_, p_, t_, lambda_); auto args = std::tuple( (double)shallow,