mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-26 03:56:50 +08:00
Detect constant expressions in autograd and return 0 grad early.
This commit is contained in:
@@ -120,6 +120,9 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
|
||||
template <typename T, typename Tuple>
|
||||
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
|
||||
|
||||
template <typename... Ts>
|
||||
constexpr bool AreAllConstantV = (std::remove_reference_t<Ts>::is_constant && ...);
|
||||
}
|
||||
|
||||
template <typename T, typename ChildT>
|
||||
@@ -167,16 +170,23 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
const auto call_id = std::get<Detail::CallId>(args);
|
||||
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
||||
if constexpr (ChildT::is_constant)
|
||||
{
|
||||
grad_cache_call_id = call_id;
|
||||
grad_cache = this_->calculate_grad(args);
|
||||
return T(0.0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
return *grad_cache;
|
||||
const auto call_id = std::get<Detail::CallId>(args);
|
||||
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
||||
{
|
||||
grad_cache_call_id = call_id;
|
||||
grad_cache = this_->calculate_grad(args);
|
||||
}
|
||||
|
||||
return *grad_cache;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... ArgsTs,
|
||||
@@ -199,6 +209,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = false;
|
||||
|
||||
constexpr VariableParameter()
|
||||
{
|
||||
@@ -222,6 +234,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = true;
|
||||
|
||||
constexpr ConstantParameter()
|
||||
{
|
||||
}
|
||||
@@ -244,6 +258,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = true;
|
||||
|
||||
constexpr Constant(T x) :
|
||||
m_x(std::move(x))
|
||||
{
|
||||
@@ -270,6 +286,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||
|
||||
constexpr Sum(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
@@ -316,6 +334,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||
|
||||
constexpr Difference(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
@@ -362,6 +382,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
|
||||
|
||||
constexpr Product(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
@@ -408,6 +430,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||
|
||||
constexpr explicit Negation(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
@@ -440,6 +464,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||
|
||||
constexpr explicit Sigmoid(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
@@ -482,6 +508,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||
|
||||
constexpr explicit Pow(ArgT&& x, Id<T> exponent) :
|
||||
m_x(std::forward<ArgT>(x)),
|
||||
m_exponent(std::move(exponent))
|
||||
@@ -516,6 +544,8 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
|
||||
|
||||
constexpr explicit Log(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user