diff --git a/src/learn/autograd.h b/src/learn/autograd.h index 8a4df2ab..0b894cc4 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -7,6 +7,44 @@ #include #include +namespace Learner +{ + template + struct ValueWithGrad + { + T value; + T grad; + + ValueWithGrad& operator+=(const ValueWithGrad& rhs) + { + value += rhs.value; + grad += rhs.grad; + return *this; + } + + ValueWithGrad& operator-=(const ValueWithGrad& rhs) + { + value -= rhs.value; + grad -= rhs.grad; + return *this; + } + + ValueWithGrad& operator*=(T rhs) + { + value *= rhs; + grad *= rhs; + return *this; + } + + ValueWithGrad& operator/=(T rhs) + { + value /= rhs; + grad /= rhs; + return *this; + } + }; +} + namespace Learner::Autograd::UnivariateStatic { @@ -19,8 +57,20 @@ namespace Learner::Autograd::UnivariateStatic template using Id = typename Identity::type; + template + struct Evaluable + { + template + auto eval(const std::tuple& args) const + { + using ValueType = typename T::ValueType; + const T* this_ = static_cast(this); + return ValueWithGrad{ this_->value(args), this_->grad(args) }; + } + }; + template - struct VariableParameter + struct VariableParameter : Evaluable> { using ValueType = T; @@ -42,7 +92,7 @@ namespace Learner::Autograd::UnivariateStatic }; template - struct ConstantParameter + struct ConstantParameter : Evaluable> { using ValueType = T; @@ -64,7 +114,7 @@ namespace Learner::Autograd::UnivariateStatic }; template - struct Constant + struct Constant : Evaluable> { using ValueType = T; @@ -90,7 +140,7 @@ namespace Learner::Autograd::UnivariateStatic }; template - struct Sum + struct Sum : Evaluable> { using ValueType = T; @@ -136,7 +186,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Difference + struct Difference : Evaluable> { using ValueType = T; @@ -182,7 +232,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Product + struct Product : Evaluable> { using ValueType = T; @@ -228,7 +278,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Sigmoid + struct Sigmoid : Evaluable> { using ValueType = T; @@ -270,7 +320,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Pow + struct Pow : Evaluable> { using ValueType = T; @@ -304,7 +354,7 @@ namespace Learner::Autograd::UnivariateStatic } template - struct Log + struct Log : Evaluable> { using ValueType = T;