mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 03:26:24 +08:00
Identify a single evalation chain by ID in autograd to prevent cache reuse for subsequent evaluations of the same expression tree.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <tuple>
|
||||
#include <optional>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
@@ -76,46 +77,122 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
const std::remove_reference_t<T>&
|
||||
>;
|
||||
|
||||
namespace Detail
|
||||
{
|
||||
using CallIdType = std::uint32_t;
|
||||
|
||||
struct CallId
|
||||
{
|
||||
CallIdType call_id{};
|
||||
|
||||
constexpr CallId() :
|
||||
call_id(0)
|
||||
{
|
||||
}
|
||||
|
||||
constexpr CallId(CallIdType id) :
|
||||
call_id(id)
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] bool operator==(CallId rhs) const noexcept
|
||||
{
|
||||
return call_id == rhs.call_id;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool operator!=(CallId rhs) const noexcept
|
||||
{
|
||||
return call_id != rhs.call_id;
|
||||
}
|
||||
};
|
||||
|
||||
[[nodiscard]] inline CallId next_call_id()
|
||||
{
|
||||
static thread_local CallIdType s_call_id = 0;
|
||||
return CallId{ s_call_id++ };
|
||||
}
|
||||
|
||||
template <typename T, typename Tuple>
|
||||
struct TupleContains;
|
||||
|
||||
template <typename T, typename... Us>
|
||||
struct TupleContains<T, std::tuple<Us...>> : std::disjunction<std::is_same<T, Us>...> {};
|
||||
|
||||
template <typename T, typename Tuple>
|
||||
constexpr bool TupleContainsV = TupleContains<T, Tuple>::value;
|
||||
}
|
||||
|
||||
template <typename T, typename ChildT>
|
||||
struct Evaluable
|
||||
{
|
||||
constexpr Evaluable() = default;
|
||||
|
||||
// We append a unique call id so that we can invalidate the cache when
|
||||
// the next computation starts. A single evaluation should see
|
||||
// the same call_id at every node.
|
||||
template <typename... ArgsTs>
|
||||
[[nodiscard]] auto eval(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
return ValueWithGrad<T>{ value(args), grad(args) };
|
||||
const auto call_id = Detail::next_call_id();
|
||||
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||
return ValueWithGrad<T>{ value(new_args), grad(new_args) };
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
template <typename... ArgsTs,
|
||||
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
if (!value_cache.has_value())
|
||||
const auto call_id = std::get<Detail::CallId>(args);
|
||||
if (!value_cache.has_value() || value_cache_call_id != call_id)
|
||||
{
|
||||
value_cache_call_id = call_id;
|
||||
value_cache = this_->calculate_value(args);
|
||||
}
|
||||
|
||||
return *value_cache;
|
||||
}
|
||||
|
||||
template <typename... ArgsTs>
|
||||
template <typename... ArgsTs,
|
||||
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||
[[nodiscard]] auto value(const std::tuple<ArgsTs...>& args, ...) const
|
||||
{
|
||||
const auto call_id = Detail::next_call_id();
|
||||
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||
return value(new_args);
|
||||
}
|
||||
|
||||
template <typename... ArgsTs,
|
||||
typename SFINAE = std::enable_if_t<Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args) const
|
||||
{
|
||||
const ChildT* this_ = static_cast<const ChildT*>(this);
|
||||
|
||||
if (!grad_cache.has_value())
|
||||
const auto call_id = std::get<Detail::CallId>(args);
|
||||
if (!grad_cache.has_value() || grad_cache_call_id != call_id)
|
||||
{
|
||||
grad_cache_call_id = call_id;
|
||||
grad_cache = this_->calculate_grad(args);
|
||||
}
|
||||
|
||||
return *grad_cache;
|
||||
}
|
||||
|
||||
template <typename... ArgsTs,
|
||||
typename SFINAE = std::enable_if_t<!Detail::TupleContainsV<Detail::CallId, std::tuple<ArgsTs...>>>>
|
||||
[[nodiscard]] auto grad(const std::tuple<ArgsTs...>& args, ...) const
|
||||
{
|
||||
const auto call_id = Detail::next_call_id();
|
||||
const auto new_args = std::tuple_cat(args, std::tuple(call_id));
|
||||
return grad(new_args);
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::optional<T> value_cache;
|
||||
mutable std::optional<T> grad_cache;
|
||||
mutable Detail::CallId value_cache_call_id{};
|
||||
mutable Detail::CallId grad_cache_call_id{};
|
||||
};
|
||||
|
||||
template <typename T, int I>
|
||||
|
||||
Reference in New Issue
Block a user