diff --git a/src/learn/autograd.h b/src/learn/autograd.h index 5c573c0f..7006121a 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace Learner { @@ -76,46 +77,122 @@ namespace Learner::Autograd::UnivariateStatic const std::remove_reference_t& >; + namespace Detail + { + using CallIdType = std::uint32_t; + + struct CallId + { + CallIdType call_id{}; + + constexpr CallId() : + call_id(0) + { + } + + constexpr CallId(CallIdType id) : + call_id(id) + { + } + + [[nodiscard]] bool operator==(CallId rhs) const noexcept + { + return call_id == rhs.call_id; + } + + [[nodiscard]] bool operator!=(CallId rhs) const noexcept + { + return call_id != rhs.call_id; + } + }; + + [[nodiscard]] inline CallId next_call_id() + { + static thread_local CallIdType s_call_id = 0; + return CallId{ s_call_id++ }; + } + + template + struct TupleContains; + + template + struct TupleContains> : std::disjunction...> {}; + + template + constexpr bool TupleContainsV = TupleContains::value; + } + template struct Evaluable { constexpr Evaluable() = default; + // We append a unique call id so that we can invalidate the cache when + // the next computation starts. A single evaluation should see + // the same call_id at every node. template [[nodiscard]] auto eval(const std::tuple& args) const { - return ValueWithGrad{ value(args), grad(args) }; + const auto call_id = Detail::next_call_id(); + const auto new_args = std::tuple_cat(args, std::tuple(call_id)); + return ValueWithGrad{ value(new_args), grad(new_args) }; } - template + template >>> [[nodiscard]] auto value(const std::tuple& args) const { const ChildT* this_ = static_cast(this); - if (!value_cache.has_value()) + const auto call_id = std::get(args); + if (!value_cache.has_value() || value_cache_call_id != call_id) { + value_cache_call_id = call_id; value_cache = this_->calculate_value(args); } return *value_cache; } - template + template >>> + [[nodiscard]] auto value(const std::tuple& args, ...) const + { + const auto call_id = Detail::next_call_id(); + const auto new_args = std::tuple_cat(args, std::tuple(call_id)); + return value(new_args); + } + + template >>> [[nodiscard]] auto grad(const std::tuple& args) const { const ChildT* this_ = static_cast(this); - if (!grad_cache.has_value()) + const auto call_id = std::get(args); + if (!grad_cache.has_value() || grad_cache_call_id != call_id) { + grad_cache_call_id = call_id; grad_cache = this_->calculate_grad(args); } return *grad_cache; } + template >>> + [[nodiscard]] auto grad(const std::tuple& args, ...) const + { + const auto call_id = Detail::next_call_id(); + const auto new_args = std::tuple_cat(args, std::tuple(call_id)); + return grad(new_args); + } + private: mutable std::optional value_cache; mutable std::optional grad_cache; + mutable Detail::CallId value_cache_call_id{}; + mutable Detail::CallId grad_cache_call_id{}; }; template