mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 11:06:58 +08:00
Add memoization to the autograd expression evaluator.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <type_traits>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <optional>
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
@@ -62,20 +63,48 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
template <typename T>
|
||||
using Id = typename Identity<T>::type;
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename ChildT>
|
||||
struct Evaluable
|
||||
{
|
||||
template <typename... ArgsTs>
|
||||
auto eval(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
using ValueType = typename T::ValueType;
|
||||
const T* this_ = static_cast<const T*>(this);
|
||||
return ValueWithGrad<ValueType>{ this_->value(args), this_->grad(args) };
|
||||
return ValueWithGrad<T>{ value(args), grad(args) };
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
auto value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
if (!value_cache.has_value())
|
||||
{
|
||||
value_cache = this_->calculate_value(args);
|
||||
}
|
||||
|
||||
return *value_cache;
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
auto grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
if (!grad_cache.has_value())
|
||||
{
|
||||
grad_cache = this_->calculate_grad(args);
|
||||
}
|
||||
|
||||
return *grad_cache;
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::optional<T> value_cache;
|
||||
mutable std::optional<T> grad_cache;
|
||||
};
|
||||
|
||||
template <typename T, int I>
|
||||
struct VariableParameter : Evaluable<VariableParameter<T, I>>
|
||||
struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -84,20 +113,20 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return std::get<I>(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>&) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>&) const
|
||||
{
|
||||
return T(1.0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int I>
|
||||
struct ConstantParameter : Evaluable<ConstantParameter<T, I>>
|
||||
struct ConstantParameter : Evaluable<T, ConstantParameter<T, I>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -106,20 +135,20 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return std::get<I>(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>&) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>&) const
|
||||
{
|
||||
return T(0.0);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Constant : Evaluable<Constant<T>>
|
||||
struct Constant : Evaluable<T, Constant<T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -129,13 +158,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>&) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>&) const
|
||||
{
|
||||
return m_x;
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>&) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>&) const
|
||||
{
|
||||
return T(0.0);
|
||||
}
|
||||
@@ -145,7 +174,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Sum : Evaluable<Sum<LhsT, RhsT, T>>
|
||||
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -156,13 +185,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.value(args) + m_rhs.value(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.grad(args) + m_rhs.grad(args);
|
||||
}
|
||||
@@ -191,7 +220,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Difference : Evaluable<Difference<LhsT, RhsT, T>>
|
||||
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -202,13 +231,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.value(args) - m_rhs.value(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.grad(args) - m_rhs.grad(args);
|
||||
}
|
||||
@@ -237,7 +266,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
struct Product : Evaluable<Product<LhsT, RhsT, T>>
|
||||
struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -248,13 +277,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.value(args) * m_rhs.value(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_lhs.grad(args) * m_rhs.value(args) + m_lhs.value(args) * m_rhs.grad(args);
|
||||
}
|
||||
@@ -283,7 +312,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Negation : Evaluable<Negation<ArgT, T>>
|
||||
struct Negation : Evaluable<T, Negation<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -293,13 +322,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return -m_x.value(args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return -m_x.grad(args);
|
||||
}
|
||||
@@ -315,7 +344,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Sigmoid : Evaluable<Sigmoid<ArgT, T>>
|
||||
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -325,13 +354,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return value_(m_x.value(args));
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_x.grad(args) * grad_(m_x.value(args));
|
||||
}
|
||||
@@ -357,7 +386,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Pow : Evaluable<Pow<ArgT, T>>
|
||||
struct Pow : Evaluable<T, Pow<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -368,13 +397,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return std::pow(m_x.value(args), m_exponent);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_exponent * std::pow(m_x.value(args), m_exponent - T(1.0)) * m_x.grad(args);
|
||||
}
|
||||
@@ -391,7 +420,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
struct Log : Evaluable<Log<ArgT, T>>
|
||||
struct Log : Evaluable<T, Log<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
@@ -401,13 +430,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T value(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return value_(m_x.value(args));
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
T grad(const std::tuple<ArgsTs...>& args) const
|
||||
T calculate_grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return m_x.grad(args) * grad_(m_x.value(args));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user