mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
When forming an autograd expression only copy parts that are rvalue references, store references to lvalues.
This commit is contained in:
@@ -69,6 +69,13 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
template <typename T>
|
||||
using Id = typename Identity<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using StoreValueOrRef = std::conditional_t<
|
||||
std::is_rvalue_reference_v<T>,
|
||||
std::remove_reference_t<T>,
|
||||
const std::remove_reference_t<T>&
|
||||
>;
|
||||
|
||||
template <typename T, typename ChildT>
|
||||
struct Evaluable
|
||||
{
|
||||
@@ -179,14 +186,14 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
T m_x;
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
Sum(LhsT lhs, RhsT rhs) :
|
||||
m_lhs(std::move(lhs)),
|
||||
m_rhs(std::move(rhs))
|
||||
Sum(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -203,36 +210,36 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
LhsT m_lhs;
|
||||
RhsT m_rhs;
|
||||
StoreValueOrRef<LhsT> m_lhs;
|
||||
StoreValueOrRef<RhsT> m_rhs;
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator+(LhsT lhs, RhsT rhs)
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator+(LhsT&& lhs, RhsT&& rhs)
|
||||
{
|
||||
return Sum(std::move(lhs), std::move(rhs));
|
||||
return Sum<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator+(LhsT lhs, Id<T> rhs)
|
||||
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator+(LhsT&& lhs, Id<T> rhs)
|
||||
{
|
||||
return Sum(std::move(lhs), Constant(rhs));
|
||||
return Sum<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||
}
|
||||
|
||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
||||
auto operator+(Id<T> lhs, RhsT rhs)
|
||||
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||
auto operator+(Id<T> lhs, RhsT&& rhs)
|
||||
{
|
||||
return Sum(Constant(lhs), std::move(rhs));
|
||||
return Sum<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
Difference(LhsT lhs, RhsT rhs) :
|
||||
m_lhs(std::move(lhs)),
|
||||
m_rhs(std::move(rhs))
|
||||
Difference(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -249,36 +256,36 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
LhsT m_lhs;
|
||||
RhsT m_rhs;
|
||||
StoreValueOrRef<LhsT> m_lhs;
|
||||
StoreValueOrRef<RhsT> m_rhs;
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator-(LhsT lhs, RhsT rhs)
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator-(LhsT&& lhs, RhsT&& rhs)
|
||||
{
|
||||
return Difference(std::move(lhs), std::move(rhs));
|
||||
return Difference<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator-(LhsT lhs, Id<T> rhs)
|
||||
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator-(LhsT&& lhs, Id<T> rhs)
|
||||
{
|
||||
return Difference(std::move(lhs), Constant(rhs));
|
||||
return Difference<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||
}
|
||||
|
||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
||||
auto operator-(Id<T> lhs, RhsT rhs)
|
||||
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||
auto operator-(Id<T> lhs, RhsT&& rhs)
|
||||
{
|
||||
return Difference(Constant(lhs), std::move(rhs));
|
||||
return Difference<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
Product(LhsT lhs, RhsT rhs) :
|
||||
m_lhs(std::move(lhs)),
|
||||
m_rhs(std::move(rhs))
|
||||
Product(LhsT&& lhs, RhsT&& rhs) :
|
||||
m_lhs(std::forward<LhsT>(lhs)),
|
||||
m_rhs(std::forward<RhsT>(rhs))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -295,35 +302,35 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
LhsT m_lhs;
|
||||
RhsT m_rhs;
|
||||
StoreValueOrRef<LhsT> m_lhs;
|
||||
StoreValueOrRef<RhsT> m_rhs;
|
||||
};
|
||||
|
||||
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator*(LhsT lhs, RhsT rhs)
|
||||
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator*(LhsT&& lhs, RhsT&& rhs)
|
||||
{
|
||||
return Product(std::move(lhs), std::move(rhs));
|
||||
return Product<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename LhsT, typename T = typename LhsT::ValueType>
|
||||
auto operator*(LhsT lhs, Id<T> rhs)
|
||||
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
|
||||
auto operator*(LhsT&& lhs, Id<T> rhs)
|
||||
{
|
||||
return Product(std::move(lhs), Constant(rhs));
|
||||
return Product<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
|
||||
}
|
||||
|
||||
template <typename RhsT, typename T = typename RhsT::ValueType>
|
||||
auto operator*(Id<T> lhs, RhsT rhs)
|
||||
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
|
||||
auto operator*(Id<T> lhs, RhsT&& rhs)
|
||||
{
|
||||
return Product(Constant(lhs), std::move(rhs));
|
||||
return Product<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
struct Negation : Evaluable<T, Negation<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
explicit Negation(ArgT x) :
|
||||
m_x(std::move(x))
|
||||
explicit Negation(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -340,22 +347,22 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
ArgT m_x;
|
||||
StoreValueOrRef<ArgT> m_x;
|
||||
};
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
auto operator-(ArgT x)
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
auto operator-(ArgT&& x)
|
||||
{
|
||||
return Negation(std::move(x));
|
||||
return Negation<ArgT&&>(std::forward<ArgT>(x));
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
explicit Sigmoid(ArgT x) :
|
||||
m_x(std::move(x))
|
||||
explicit Sigmoid(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -372,7 +379,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
ArgT m_x;
|
||||
StoreValueOrRef<ArgT> m_x;
|
||||
|
||||
T value_(T x) const
|
||||
{
|
||||
@@ -385,19 +392,19 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
auto sigmoid(ArgT x)
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
auto sigmoid(ArgT&& x)
|
||||
{
|
||||
return Sigmoid(std::move(x));
|
||||
return Sigmoid<ArgT&&>(std::forward<ArgT>(x));
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
struct Pow : Evaluable<T, Pow<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
explicit Pow(ArgT x, Id<T> exponent) :
|
||||
m_x(std::move(x)),
|
||||
explicit Pow(ArgT&& x, Id<T> exponent) :
|
||||
m_x(std::forward<ArgT>(x)),
|
||||
m_exponent(std::move(exponent))
|
||||
{
|
||||
}
|
||||
@@ -415,23 +422,23 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
ArgT m_x;
|
||||
StoreValueOrRef<ArgT> m_x;
|
||||
T m_exponent;
|
||||
};
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
auto pow(ArgT x, Id<T> exp)
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
auto pow(ArgT&& x, Id<T> exp)
|
||||
{
|
||||
return Pow(std::move(x), std::move(exp));
|
||||
return Pow<ArgT&&>(std::forward<ArgT>(x), std::move(exp));
|
||||
}
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
struct Log : Evaluable<T, Log<ArgT, T>>
|
||||
{
|
||||
using ValueType = T;
|
||||
|
||||
explicit Log(ArgT x) :
|
||||
m_x(std::move(x))
|
||||
explicit Log(ArgT&& x) :
|
||||
m_x(std::forward<ArgT>(x))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -448,7 +455,7 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
|
||||
private:
|
||||
ArgT m_x;
|
||||
StoreValueOrRef<ArgT> m_x;
|
||||
|
||||
T value_(T x) const
|
||||
{
|
||||
@@ -461,10 +468,10 @@ namespace Learner::Autograd::UnivariateStatic
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ArgT, typename T = typename ArgT::ValueType>
|
||||
auto log(ArgT x)
|
||||
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
|
||||
auto log(ArgT&& x)
|
||||
{
|
||||
return Log(std::move(x));
|
||||
return Log<ArgT&&>(std::forward<ArgT>(x));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user