mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
More utility in autograd.
This commit is contained in:
@@ -7,6 +7,44 @@
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
template <typename T>
|
||||
struct ValueWithGrad
|
||||
{
|
||||
T value;
|
||||
T grad;
|
||||
|
||||
ValueWithGrad& operator+=(const ValueWithGrad<T>& rhs)
|
||||
{
|
||||
value += rhs.value;
|
||||
grad += rhs.grad;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ValueWithGrad& operator-=(const ValueWithGrad<T>& rhs)
|
||||
{
|
||||
value -= rhs.value;
|
||||
grad -= rhs.grad;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ValueWithGrad& operator*=(T rhs)
|
||||
{
|
||||
value *= rhs;
|
||||
grad *= rhs;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ValueWithGrad& operator/=(T rhs)
|
||||
{
|
||||
value /= rhs;
|
||||
grad /= rhs;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
|
||||
@@ -19,8 +57,20 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
template <typename T>
|
||||
using Id = typename Identity<T>::type;
|
||||
|
||||
template <typename T>
|
||||
struct Evaluable
|
||||
{
|
||||
template <typename... ArgsTs>
|
||||
auto eval(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
using ValueType = typename T::ValueType;
|
||||
const T* this_ = static_cast<const T*>(this);
|
||||
return ValueWithGrad<ValueType>{ this_->value(args), this_->grad(args) };
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int I>
|
||||
struct VariableParameter
|
||||
struct VariableParameter : Evaluable<VariableParameter<T, I>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -42,7 +92,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
};
|
||||
|
||||
template <typename T, int I>
|
||||
struct ConstantParameter
|
||||
struct ConstantParameter : Evaluable<ConstantParameter<T, I>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -64,7 +114,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Constant
|
||||
struct Constant : Evaluable<Constant<T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -90,7 +140,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Sum
|
||||
struct Sum : Evaluable<Sum<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -136,7 +186,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Difference
|
||||
struct Difference : Evaluable<Difference<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -182,7 +232,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Product
|
||||
struct Product : Evaluable<Product<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -228,7 +278,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Sigmoid
|
||||
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -270,7 +320,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Pow
|
||||
struct Pow : Evaluable<Pow<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -304,7 +354,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Log
|
||||
struct Log : Evaluable<Log<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user