diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index 0b04d034..af867d42 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -215,21 +215,28 @@ namespace Learner auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0)); */ - const double epsilon = 1e-12; + constexpr double epsilon = 1e-12; - auto q_ = sigmoid(VariableParameter{} * winning_probability_coefficient); - auto p_ = sigmoid(ConstantParameter{} * winning_probability_coefficient); - auto t_ = (ConstantParameter{} + 1.0) * 0.5; - auto lambda_ = ConstantParameter{}; - auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon)); - auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon)); - auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_)); - auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_)); - auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_; - auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_; - auto loss_ = result_ - entropy_; + static thread_local auto q_ = sigmoid(VariableParameter{} * ConstantParameter{}); + static thread_local auto p_ = sigmoid(ConstantParameter{} * ConstantParameter{}); + static thread_local auto t_ = (ConstantParameter{} + 1.0) * 0.5; + static thread_local auto lambda_ = ConstantParameter{}; + static thread_local auto teacher_entropy_ = -(p_ * log(p_ + epsilon) + (1.0 - p_) * log(1.0 - p_ + epsilon)); + static thread_local auto outcome_entropy_ = -(t_ * log(t_ + epsilon) + (1.0 - t_) * log(1.0 - t_ + epsilon)); + static thread_local auto teacher_loss_ = -(p_ * log(q_) + (1.0 - p_) * log(1.0 - q_)); + static thread_local auto outcome_loss_ = -(t_ * log(q_) + (1.0 - t_) * log(1.0 - q_)); + static thread_local auto result_ = lambda_ * teacher_loss_ + (1.0 - lambda_) * outcome_loss_; + static thread_local auto entropy_ = lambda_ * teacher_entropy_ + (1.0 - lambda_) * outcome_entropy_; + static thread_local auto loss_ = result_ - entropy_; + + auto args = std::tuple( + (double)shallow, + (double)teacher_signal, + (double)result, + calculate_lambda(teacher_signal), + winning_probability_coefficient + ); - auto args = std::tuple((double)shallow, (double)teacher_signal, (double)result, calculate_lambda(teacher_signal)); return loss_.eval(args).clamp_grad(max_grad); }