diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp
index af867d42..dd893d9d 100644
--- a/src/learn/learn.cpp
+++ b/src/learn/learn.cpp
@@ -197,6 +197,38 @@ namespace Learner
         return lambda;
     }
 
+    // Builds the expression for the lambda_-weighted cross-entropy of q_
+    // against p_ (teacher) and t_ (outcome), minus the equally weighted
+    // entropies of p_ and t_:
+    //   lambda_ * CE(p_, q_) + (1 - lambda_) * CE(t_, q_)
+    //     - (lambda_ * H(p_) + (1 - lambda_) * H(t_))
+    // Only the entropy logs are offset by epsilon; log(q_) is used as-is.
+    // Every term is a static thread_local, so it is initialized on the first
+    // call per thread (per instantiation) from that call's arguments and
+    // later calls return the same object by reference.
+    template <typename ShallowT, typename TeacherT, typename ResultT, typename LambdaT>
+    static auto& cross_entropy_(
+        ShallowT& q_,
+        TeacherT& p_,
+        ResultT& t_,
+        LambdaT& lambda_
+    )
+    {
+        using namespace Learner::Autograd::UnivariateStatic;
+
+        constexpr double epsilon = 1e-12;
+
+        static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon));
+        static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon));
+        static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_));
+        static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_));
+        static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_;
+        static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_;
+        static thread_local auto loss_ = result_ - entropy_;
+
+        return loss_;
+    }
+
     static ValueWithGrad get_loss(Value shallow, Value teacher_signal, int result, int ply)
     {
         using namespace Learner::Autograd::UnivariateStatic;
@@ -215,19 +247,11 @@ namespace Learner
         auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0));
         */
 
-        constexpr double epsilon = 1e-12;
-
         static thread_local auto q_ = sigmoid(VariableParameter{} * ConstantParameter{});
         static thread_local auto p_ = sigmoid(ConstantParameter{} * ConstantParameter{});
         static thread_local auto t_ = (ConstantParameter{} + 1.0) * 0.5;
         static thread_local auto lambda_ = ConstantParameter{};
-        static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon));
-        static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon));
-        static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_));
-        static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_));
-        static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_;
-        static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_;
-        static thread_local auto loss_ = result_ - entropy_;
+        static thread_local auto loss_ = cross_entropy_(q_, p_, t_, lambda_);
 
         auto args = std::tuple(
             (double)shallow,