diff --git a/src/learn/autograd.h b/src/learn/autograd.h index a4ad8b7f..2b0eee3a 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace Learner { @@ -62,20 +63,48 @@ namespace Learner::Autograd::UnivariateStatic template using Id = typename Identity::type; - template + template struct Evaluable { template auto eval(const std::tuple& args) const { - using ValueType = typename T::ValueType; - const T* this_ = static_cast(this); - return ValueWithGrad{ this_->value(args), this_->grad(args) }; + return ValueWithGrad{ value(args), grad(args) }; } + + template + auto value(const std::tuple& args) const + { + const ChildT* this_ = static_cast(this); + + if (!value_cache.has_value()) + { + value_cache = this_->calculate_value(args); + } + + return *value_cache; + } + + template + auto grad(const std::tuple& args) const + { + const ChildT* this_ = static_cast(this); + + if (!grad_cache.has_value()) + { + grad_cache = this_->calculate_grad(args); + } + + return *grad_cache; + } + + private: + mutable std::optional value_cache; + mutable std::optional grad_cache; }; template - struct VariableParameter : Evaluable> + struct VariableParameter : Evaluable> { using ValueType = T; @@ -84,20 +113,20 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return std::get(args); } template - T grad(const std::tuple&) const + T calculate_grad(const std::tuple&) const { return T(1.0); } }; template - struct ConstantParameter : Evaluable> + struct ConstantParameter : Evaluable> { using ValueType = T; @@ -106,20 +135,20 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return std::get(args); } template - T grad(const std::tuple&) const + T calculate_grad(const std::tuple&) const { return T(0.0); } }; template - struct Constant : Evaluable> + struct Constant : Evaluable> { using ValueType = T; @@ -129,13 +158,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple&) const + T calculate_value(const std::tuple&) const { return m_x; } template - T grad(const std::tuple&) const + T calculate_grad(const std::tuple&) const { return T(0.0); } @@ -145,7 +174,7 @@ namespace Learner::Autograd::UnivariateStatic }; template - struct Sum : Evaluable> + struct Sum : Evaluable> { using ValueType = T; @@ -156,13 +185,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return m_lhs.value(args) + m_rhs.value(args); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_lhs.grad(args) + m_rhs.grad(args); } @@ -191,7 +220,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Difference : Evaluable> + struct Difference : Evaluable> { using ValueType = T; @@ -202,13 +231,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return m_lhs.value(args) - m_rhs.value(args); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_lhs.grad(args) - m_rhs.grad(args); } @@ -237,7 +266,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Product : Evaluable> + struct Product : Evaluable> { using ValueType = T; @@ -248,13 +277,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return m_lhs.value(args) * m_rhs.value(args); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_lhs.grad(args) * m_rhs.value(args) + m_lhs.value(args) * m_rhs.grad(args); } @@ -283,7 +312,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Negation : Evaluable> + struct Negation : Evaluable> { using ValueType = T; @@ -293,13 +322,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return -m_x.value(args); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return -m_x.grad(args); } @@ -315,7 +344,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Sigmoid : Evaluable> + struct Sigmoid : Evaluable> { using ValueType = T; @@ -325,13 +354,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return value_(m_x.value(args)); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_x.grad(args) * grad_(m_x.value(args)); } @@ -357,7 +386,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Pow : Evaluable> + struct Pow : Evaluable> { using ValueType = T; @@ -368,13 +397,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return std::pow(m_x.value(args), m_exponent); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_exponent * std::pow(m_x.value(args), m_exponent - T(1.0)) * m_x.grad(args); } @@ -391,7 +420,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Log : Evaluable> + struct Log : Evaluable> { using ValueType = T; @@ -401,13 +430,13 @@ namespace Learner::Autograd::UnivariateStatic } template - T value(const std::tuple& args) const + T calculate_value(const std::tuple& args) const { return value_(m_x.value(args)); } template - T grad(const std::tuple& args) const + T calculate_grad(const std::tuple& args) const { return m_x.grad(args) * grad_(m_x.value(args)); }