diff --git a/src/learn/autograd.h b/src/learn/autograd.h index 45bee469..4383dfab 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -209,7 +209,7 @@ namespace Learner::Autograd::UnivariateStatic struct VariableParameter : Evaluable> { using ValueType = T; - + static constexpr bool is_constant = false; constexpr VariableParameter() @@ -281,6 +281,36 @@ namespace Learner::Autograd::UnivariateStatic T m_x; }; + // The "constant" may change between executions, but is assumed to be + // constant during a single evaluation. + template + struct ConstantRef : Evaluable> + { + using ValueType = T; + + static constexpr bool is_constant = true; + + constexpr ConstantRef(const T& x) : + m_x(x) + { + } + + template + [[nodiscard]] T calculate_value(const std::tuple&) const + { + return m_x; + } + + template + [[nodiscard]] T calculate_grad(const std::tuple&) const + { + return T(0.0); + } + + private: + const T& m_x; + }; + template ::ValueType> struct Sum : Evaluable> { diff --git a/src/learn/learn.cpp b/src/learn/learn.cpp index dd893d9d..8e32836b 100644 --- a/src/learn/learn.cpp +++ b/src/learn/learn.cpp @@ -220,6 +220,16 @@ namespace Learner return loss_; } + template + static auto& expected_perf_(ValueT&& v_) + { + using namespace Learner::Autograd::UnivariateStatic; + + static thread_local auto perf_ = sigmoid(std::forward(v_) * ConstantRef(winning_probability_coefficient)); + + return perf_; + } + static ValueWithGrad get_loss(Value shallow, Value teacher_signal, int result, int ply) { using namespace Learner::Autograd::UnivariateStatic; @@ -238,18 +248,17 @@ namespace Learner auto loss_ = pow(q_ - p_, 2.0) * (1.0 / (2400.0 * 2.0 * 600.0)); */ - static thread_local auto q_ = sigmoid(VariableParameter{} * ConstantParameter{}); - static thread_local auto p_ = sigmoid(ConstantParameter{} * ConstantParameter{}); + static thread_local auto q_ = expected_perf_(VariableParameter{}); + static thread_local auto p_ = expected_perf_(ConstantParameter{}); static thread_local auto t_ = (ConstantParameter{} + 1.0) * 0.5; static thread_local auto lambda_ = ConstantParameter{}; static thread_local auto loss_ = cross_entropy_(q_, p_, t_, lambda_); auto args = std::tuple( - (double)shallow, - (double)teacher_signal, - (double)result, - calculate_lambda(teacher_signal), - winning_probability_coefficient + (double)shallow, + (double)teacher_signal, + (double)result, + calculate_lambda(teacher_signal) ); return loss_.eval(args).clamp_grad(max_grad);