Detect constant expressions in autograd and return 0 grad early.

This commit is contained in:
Tomasz Sobczyk
2020-11-30 16:50:51 +01:00
committed by nodchip
parent e975889132
commit cbd973fdaa

View File

@@ -120,6 +120,9 @@ namespace Learner::Autograd::UnivariateStatic
template <typename T, typename Tuple>
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
template <typename... Ts>
constexpr bool AreAllConstantV = (std::remove_reference_t<Ts>::is_constant && ...);
}
template <typename T, typename ChildT>
@@ -167,16 +170,23 @@ namespace Learner::Autograd::UnivariateStatic
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
{
const ChildT* this_ = static_cast<const ChildT*>(this);
const auto call_id = std::get<Detail::CallId>(args);
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
if constexpr (ChildT::is_constant)
{
grad_cache_call_id = call_id;
grad_cache = this_->calculate_grad(args);
return T(0.0);
}
else
{
const ChildT* this_ = static_cast<const ChildT*>(this);
return *grad_cache;
const auto call_id = std::get<Detail::CallId>(args);
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
{
grad_cache_call_id = call_id;
grad_cache = this_->calculate_grad(args);
}
return *grad_cache;
}
}
template <typename... ArgsTs,
@@ -199,6 +209,8 @@ namespace Learner::Autograd::UnivariateStatic
struct VariableParameter : Evaluable<T, VariableParameter<T, I>>
{
using ValueType = T;
static constexpr bool is_constant = false;
constexpr VariableParameter()
{
@@ -222,6 +234,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = true;
constexpr ConstantParameter()
{
}
@@ -244,6 +258,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = true;
constexpr Constant(T x) :
m_x(std::move(x))
{
@@ -270,6 +286,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Sum(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
@@ -316,6 +334,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Difference(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
@@ -362,6 +382,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<LhsT, RhsT>;
constexpr Product(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
@@ -408,6 +430,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Negation(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{
@@ -440,6 +464,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Sigmoid(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{
@@ -482,6 +508,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Pow(ArgT&& x, Id<T> exponent) :
m_x(std::forward<ArgT>(x)),
m_exponent(std::move(exponent))
@@ -516,6 +544,8 @@ namespace Learner::Autograd::UnivariateStatic
{
using ValueType = T;
static constexpr bool is_constant = Detail::AreAllConstantV<ArgT>;
constexpr explicit Log(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{