From cbd973fdaaec0685717441a0c5418a95f5527acc Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Mon, 30 Nov 2020 16:50:51 +0100 Subject: [PATCH] Detect constant expressions in autograd and return 0 grad early. --- src/learn/autograd.h | 44 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/learn/autograd.h b/src/learn/autograd.h index 7006121a..45bee469 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -120,6 +120,9 @@ namespace Learner::Autograd::UnivariateStatic template constexpr bool TupleContainsV = TupleContains::value; + + template + constexpr bool AreAllConstantV = (std::remove_reference_t::is_constant && ...); } template @@ -167,16 +170,23 @@ namespace Learner::Autograd::UnivariateStatic typename SFINAE = std::enable_if_t>>> [[nodiscard]] auto grad(const std::tuple& args) const { - const ChildT* this_ = static_cast(this); - - const auto call_id = std::get(args); - if (!grad_cache.has_value() || grad_cache_call_id != call_id) + if constexpr (ChildT::is_constant) { - grad_cache_call_id = call_id; - grad_cache = this_->calculate_grad(args); + return T(0.0); } + else + { + const ChildT* this_ = static_cast(this); - return *grad_cache; + const auto call_id = std::get(args); + if (!grad_cache.has_value() || grad_cache_call_id != call_id) + { + grad_cache_call_id = call_id; + grad_cache = this_->calculate_grad(args); + } + + return *grad_cache; + } } template > { using ValueType = T; + + static constexpr bool is_constant = false; constexpr VariableParameter() { @@ -222,6 +234,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = true; + constexpr ConstantParameter() { } @@ -244,6 +258,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = true; + constexpr Constant(T x) : m_x(std::move(x)) { @@ -270,6 +286,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr Sum(LhsT&& lhs, RhsT&& rhs) : m_lhs(std::forward(lhs)), m_rhs(std::forward(rhs)) @@ -316,6 +334,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr Difference(LhsT&& lhs, RhsT&& rhs) : m_lhs(std::forward(lhs)), m_rhs(std::forward(rhs)) @@ -362,6 +382,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr Product(LhsT&& lhs, RhsT&& rhs) : m_lhs(std::forward(lhs)), m_rhs(std::forward(rhs)) @@ -408,6 +430,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr explicit Negation(ArgT&& x) : m_x(std::forward(x)) { @@ -440,6 +464,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr explicit Sigmoid(ArgT&& x) : m_x(std::forward(x)) { @@ -482,6 +508,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr explicit Pow(ArgT&& x, Id exponent) : m_x(std::forward(x)), m_exponent(std::move(exponent)) @@ -516,6 +544,8 @@ namespace Learner::Autograd::UnivariateStatic { using ValueType = T; + static constexpr bool is_constant = Detail::AreAllConstantV; + constexpr explicit Log(ArgT&& x) : m_x(std::forward(x)) {