More utility in autograd.

This commit is contained in:
Tomasz Sobczyk
2020-11-29 11:33:35 +01:00
committed by nodchip
parent 6ce0245787
commit 541fb8177a

View File

@@ -7,6 +7,44 @@
#include <memory>
#include <tuple>
namespace Learner
{
template <typename T>
struct ValueWithGrad
{
T value;
T grad;
ValueWithGrad& operator+=(const ValueWithGrad<T>& rhs)
{
value += rhs.value;
grad += rhs.grad;
return *this;
}
ValueWithGrad& operator-=(const ValueWithGrad<T>& rhs)
{
value -= rhs.value;
grad -= rhs.grad;
return *this;
}
ValueWithGrad& operator*=(T rhs)
{
value *= rhs;
grad *= rhs;
return *this;
}
ValueWithGrad& operator/=(T rhs)
{
value /= rhs;
grad /= rhs;
return *this;
}
};
}
namespace Learner::Autograd::UnivariateStatic
{
@@ -19,8 +57,20 @@ namespace Learner::Autograd::UnivariateStatic
template <typename T>
using Id = typename Identity<T>::type;
template <typename T>
struct Evaluable
{
template <typename... ArgsTs>
auto eval(const std::tuple<ArgsTs...>& args) const
{
using ValueType = typename T::ValueType;
const T* this_ = static_cast<const T*>(this);
return ValueWithGrad<ValueType>{ this_->value(args), this_->grad(args) };
}
};
template <typename T, int I>
struct VariableParameter
struct VariableParameter : Evaluable<VariableParameter<T, I>>
{
using ValueType = T;
@@ -42,7 +92,7 @@ namespace Learner::Autograd::UnivariateStatic
};
template <typename T, int I>
struct ConstantParameter
struct ConstantParameter : Evaluable<ConstantParameter<T, I>>
{
using ValueType = T;
@@ -64,7 +114,7 @@ namespace Learner::Autograd::UnivariateStatic
};
template <typename T>
struct Constant
struct Constant : Evaluable<Constant<T>>
{
using ValueType = T;
@@ -90,7 +140,7 @@ namespace Learner::Autograd::UnivariateStatic
};
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
struct Sum
struct Sum : Evaluable<Sum<LhsT, RhsT, T>>
{
using ValueType = T;
@@ -136,7 +186,7 @@ namespace Learner::Autograd::UnivariateStatic
}
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
struct Difference
struct Difference : Evaluable<Difference<LhsT, RhsT, T>>
{
using ValueType = T;
@@ -182,7 +232,7 @@ namespace Learner::Autograd::UnivariateStatic
}
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
struct Product
struct Product : Evaluable<Product<LhsT, RhsT, T>>
{
using ValueType = T;
@@ -228,7 +278,7 @@ namespace Learner::Autograd::UnivariateStatic
}
template <typename ArgT, typename T = typename ArgT::ValueType>
struct Sigmoid
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
{
using ValueType = T;
@@ -270,7 +320,7 @@ namespace Learner::Autograd::UnivariateStatic
}
template <typename ArgT, typename T = typename ArgT::ValueType>
struct Pow
struct Pow : Evaluable<Pow<ArgT, T>>
{
using ValueType = T;
@@ -304,7 +354,7 @@ namespace Learner::Autograd::UnivariateStatic
}
template <typename ArgT, typename T = typename ArgT::ValueType>
struct Log
struct Log : Evaluable<Log<ArgT, T>>
{
using ValueType = T;