Identify a single evaluation chain by ID in autograd to prevent cache reuse across subsequent evaluations of the same expression tree.

Tomasz Sobczyk
2020-11-30 14:01:31 +01:00
committed by nodchip
parent cb812c742c
commit 8adf00ae6e

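In essence, every top-level eval() now draws a fresh id from a thread-local counter and threads it through the argument tuple; each node's cached value or gradient is reused only while that id matches the one it was computed under. Below is a minimal standalone sketch of that pattern (CachedSquare and its members are invented for illustration and are not part of the Learner code); the SFINAE overload dispatch used by the actual value()/grad() pair is sketched separately after the diff.

// Standalone sketch (not the Learner implementation): a per-evaluation id
// invalidates node caches at the start of every new evaluation.
#include <cstdint>
#include <optional>
#include <iostream>

using CallId = std::uint32_t;

// Mirrors the role of Detail::next_call_id(): one fresh id per top-level evaluation.
inline CallId next_call_id()
{
    static thread_local CallId s_call_id = 0;
    return s_call_id++;
}

// Hypothetical node caching an expensive subexpression per call id.
struct CachedSquare
{
    double eval(double x, CallId call_id) const
    {
        // Recompute when the cache is empty or was filled under an older call id.
        if (!cache.has_value() || cache_call_id != call_id)
        {
            cache_call_id = call_id;
            cache = x * x; // stand-in for an expensive computation
        }
        return *cache;
    }

private:
    mutable std::optional<double> cache;
    mutable CallId cache_call_id{};
};

int main()
{
    CachedSquare node;

    const CallId first = next_call_id();
    std::cout << node.eval(3.0, first) << '\n';  // computes: 9
    std::cout << node.eval(3.0, first) << '\n';  // same evaluation: cache hit, still 9

    const CallId second = next_call_id();
    std::cout << node.eval(4.0, second) << '\n'; // new id invalidates the cache: 16
}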

@@ -8,6 +8,7 @@
#include <tuple>
#include <optional>
#include <algorithm>
+#include <cstdint>
namespace Learner
{
@@ -76,46 +77,122 @@ namespace Learner::Autograd::UnivariateStatic
const std::remove_reference_t<T>&
>;
+namespace Detail
+{
+using CallIdType = std::uint32_t;
+struct CallId
+{
+CallIdType call_id{};
+constexpr CallId() :
+call_id(0)
+{
+}
+constexpr CallId(CallIdType id) :
+call_id(id)
+{
+}
+[[nodiscard]] bool operator==(CallId rhs) const noexcept
+{
+return call_id == rhs.call_id;
+}
+[[nodiscard]] bool operator!=(CallId rhs) const noexcept
+{
+return call_id != rhs.call_id;
+}
+};
+[[nodiscard]] inline CallId next_call_id()
+{
+static thread_local CallIdType s_call_id = 0;
+return CallId{ s_call_id++ };
+}
+template <typename T, typename Tuple>
+struct TupleContains;
+template <typename T, typename... Us>
+struct TupleContains<T, std::tuple<Us...>> : std::disjunction<std::is_same<T, Us>...> {};
+template <typename T, typename Tuple>
+constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
+}
template <typename T, typename ChildT>
struct Evaluable
{
constexpr Evaluable() = default;
+// We append a unique call id so that we can invalidate the cache when
+// the next computation starts. A single evaluation should see
+// the same call_id at every node.
template <typename... ArgsTs>
[[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const
{
-return ValueWithGrad<T>{ value(args), grad(args) };
+const auto call_id = Detail::next_call_id();
+const auto new_args = std::tuple_cat(args, std::tuple(call_id));
+return ValueWithGrad<T>{ value(new_args), grad(new_args) };
}
-template <typename... ArgsTs>
+template <typename... ArgsTs,
+typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const
{
const ChildT* this_ = static_cast<const ChildT*>(this);
-if (!value_cache.has_value())
+const auto call_id = std::get<Detail::CallId>(args);
+if (!value_cache.has_value() || value_cache_call_id != call_id)
{
+value_cache_call_id = call_id;
value_cache = this_->calculate_value(args);
}
return *value_cache;
}
-template <typename... ArgsTs>
+template <typename... ArgsTs,
+typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
+[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args, ...) const
+{
+const auto call_id = Detail::next_call_id();
+const auto new_args = std::tuple_cat(args, std::tuple(call_id));
+return value(new_args);
+}
+template <typename... ArgsTs,
+typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
{
const ChildT* this_ = static_cast<const ChildT*>(this);
-if (!grad_cache.has_value())
+const auto call_id = std::get<Detail::CallId>(args);
+if (!grad_cache.has_value() || grad_cache_call_id != call_id)
{
+grad_cache_call_id = call_id;
grad_cache = this_->calculate_grad(args);
}
return *grad_cache;
}
+template <typename... ArgsTs,
+typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
+[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args, ...) const
+{
+const auto call_id = Detail::next_call_id();
+const auto new_args = std::tuple_cat(args, std::tuple(call_id));
+return grad(new_args);
+}
private:
mutable std::optional<T> value_cache;
mutable std::optional<T> grad_cache;
+mutable Detail::CallId value_cache_call_id{};
+mutable Detail::CallId grad_cache_call_id{};
};
template <typename T, int I>
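The overload selection in the diff hinges on the Detail::TupleContains trait: the value()/grad() overloads taking a tuple that already carries a CallId use the caching path, while a C-style-variadic fallback appends a fresh id and forwards. The sketch below reproduces that dispatch in isolation; describe() and the literal ids are invented for illustration only and are not part of the Learner code.

#include <cstdint>
#include <iostream>
#include <tuple>
#include <type_traits>

struct CallId { std::uint32_t id{}; };

// Trait answering: does the tuple type contain T? (same shape as Detail::TupleContains)
template <typename T, typename Tuple>
struct TupleContains;

template <typename T, typename... Us>
struct TupleContains<T, std::tuple<Us...>> : std::disjunction<std::is_same<T, Us>...> {};

template <typename T, typename Tuple>
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;

// Chosen when the argument tuple already carries a CallId.
template <typename... ArgsTs,
    typename SFINAE = std::enable_if_t<TupleContainsV<CallId, std::tuple<ArgsTs...>>>>
void describe(const std::tuple<ArgsTs...>& args)
{
    std::cout << "call id present: " << std::get<CallId>(args).id << '\n';
}

// Fallback for tuples without a CallId. The trailing C-style ellipsis gives the two
// templates distinct signatures (default template arguments alone would not), so this
// is not a redefinition; it appends a fresh id and forwards to the overload above.
template <typename... ArgsTs,
    typename SFINAE = std::enable_if_t<!TupleContainsV<CallId, std::tuple<ArgsTs...>>>>
void describe(const std::tuple<ArgsTs...>& args, ...)
{
    describe(std::tuple_cat(args, std::tuple(CallId{ 42 })));
}

int main()
{
    describe(std::tuple(1, 2.5));              // no CallId: fallback appends id 42
    describe(std::tuple(1, 2.5, CallId{ 7 })); // CallId already present: prints 7
}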