When forming an autograd expression only copy parts that are rvalue references, store references to lvalues.

This commit is contained in:
Tomasz Sobczyk
2020-11-29 19:06:31 +01:00
committed by nodchip
parent a5c20bee5b
commit aec6017195

View File

@@ -69,6 +69,13 @@ namespace Learner::Autograd::UnivariateStatic
template <typename T>
using Id = typename Identity<T>::type;
template <typename T>
using StoreValueOrRef = std::conditional_t<
std::is_rvalue_reference_v<T>,
std::remove_reference_t<T>,
const std::remove_reference_t<T>&
>;
template <typename T, typename ChildT>
struct Evaluable
{
@@ -179,14 +186,14 @@ namespace Learner::Autograd::UnivariateStatic
T m_x;
};
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Sum : Evaluable<T, Sum<LhsT, RhsT, T>>
{
using ValueType = T;
Sum(LhsT lhs, RhsT rhs) :
m_lhs(std::move(lhs)),
m_rhs(std::move(rhs))
Sum(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
{
}
@@ -203,36 +210,36 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
LhsT m_lhs;
RhsT m_rhs;
StoreValueOrRef<LhsT> m_lhs;
StoreValueOrRef<RhsT> m_rhs;
};
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
auto operator+(LhsT lhs, RhsT rhs)
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator+(LhsT&& lhs, RhsT&& rhs)
{
return Sum(std::move(lhs), std::move(rhs));
return Sum<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
}
template <typename LhsT, typename T = typename LhsT::ValueType>
auto operator+(LhsT lhs, Id<T> rhs)
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator+(LhsT&& lhs, Id<T> rhs)
{
return Sum(std::move(lhs), Constant(rhs));
return Sum<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
}
template <typename RhsT, typename T = typename RhsT::ValueType>
auto operator+(Id<T> lhs, RhsT rhs)
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator+(Id<T> lhs, RhsT&& rhs)
{
return Sum(Constant(lhs), std::move(rhs));
return Sum<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
}
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Difference : Evaluable<T, Difference<LhsT, RhsT, T>>
{
using ValueType = T;
Difference(LhsT lhs, RhsT rhs) :
m_lhs(std::move(lhs)),
m_rhs(std::move(rhs))
Difference(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
{
}
@@ -249,36 +256,36 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
LhsT m_lhs;
RhsT m_rhs;
StoreValueOrRef<LhsT> m_lhs;
StoreValueOrRef<RhsT> m_rhs;
};
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
auto operator-(LhsT lhs, RhsT rhs)
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator-(LhsT&& lhs, RhsT&& rhs)
{
return Difference(std::move(lhs), std::move(rhs));
return Difference<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
}
template <typename LhsT, typename T = typename LhsT::ValueType>
auto operator-(LhsT lhs, Id<T> rhs)
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator-(LhsT&& lhs, Id<T> rhs)
{
return Difference(std::move(lhs), Constant(rhs));
return Difference<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
}
template <typename RhsT, typename T = typename RhsT::ValueType>
auto operator-(Id<T> lhs, RhsT rhs)
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator-(Id<T> lhs, RhsT&& rhs)
{
return Difference(Constant(lhs), std::move(rhs));
return Difference<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
}
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
struct Product : Evaluable<T, Product<LhsT, RhsT, T>>
{
using ValueType = T;
Product(LhsT lhs, RhsT rhs) :
m_lhs(std::move(lhs)),
m_rhs(std::move(rhs))
Product(LhsT&& lhs, RhsT&& rhs) :
m_lhs(std::forward<LhsT>(lhs)),
m_rhs(std::forward<RhsT>(rhs))
{
}
@@ -295,35 +302,35 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
LhsT m_lhs;
RhsT m_rhs;
StoreValueOrRef<LhsT> m_lhs;
StoreValueOrRef<RhsT> m_rhs;
};
template <typename LhsT, typename RhsT, typename T = typename LhsT::ValueType>
auto operator*(LhsT lhs, RhsT rhs)
template <typename LhsT, typename RhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator*(LhsT&& lhs, RhsT&& rhs)
{
return Product(std::move(lhs), std::move(rhs));
return Product<LhsT&&, RhsT&&>(std::forward<LhsT>(lhs), std::forward<RhsT>(rhs));
}
template <typename LhsT, typename T = typename LhsT::ValueType>
auto operator*(LhsT lhs, Id<T> rhs)
template <typename LhsT, typename T = typename std::remove_reference_t<LhsT>::ValueType>
auto operator*(LhsT&& lhs, Id<T> rhs)
{
return Product(std::move(lhs), Constant(rhs));
return Product<LhsT&&, Constant<T>&&>(std::forward<LhsT>(lhs), Constant(rhs));
}
template <typename RhsT, typename T = typename RhsT::ValueType>
auto operator*(Id<T> lhs, RhsT rhs)
template <typename RhsT, typename T = typename std::remove_reference_t<RhsT>::ValueType>
auto operator*(Id<T> lhs, RhsT&& rhs)
{
return Product(Constant(lhs), std::move(rhs));
return Product<Constant<T>&&, RhsT&&>(Constant(lhs), std::forward<RhsT>(rhs));
}
template <typename ArgT, typename T = typename ArgT::ValueType>
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Negation : Evaluable<T, Negation<ArgT, T>>
{
using ValueType = T;
explicit Negation(ArgT x) :
m_x(std::move(x))
explicit Negation(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{
}
@@ -340,22 +347,22 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
ArgT m_x;
StoreValueOrRef<ArgT> m_x;
};
template <typename ArgT, typename T = typename ArgT::ValueType>
auto operator-(ArgT x)
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto operator-(ArgT&& x)
{
return Negation(std::move(x));
return Negation<ArgT&&>(std::forward<ArgT>(x));
}
template <typename ArgT, typename T = typename ArgT::ValueType>
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Sigmoid : Evaluable<T, Sigmoid<ArgT, T>>
{
using ValueType = T;
explicit Sigmoid(ArgT x) :
m_x(std::move(x))
explicit Sigmoid(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{
}
@@ -372,7 +379,7 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
ArgT m_x;
StoreValueOrRef<ArgT> m_x;
T value_(T x) const
{
@@ -385,19 +392,19 @@ namespace Learner::Autograd::UnivariateStatic
}
};
template <typename ArgT, typename T = typename ArgT::ValueType>
auto sigmoid(ArgT x)
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto sigmoid(ArgT&& x)
{
return Sigmoid(std::move(x));
return Sigmoid<ArgT&&>(std::forward<ArgT>(x));
}
template <typename ArgT, typename T = typename ArgT::ValueType>
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Pow : Evaluable<T, Pow<ArgT, T>>
{
using ValueType = T;
explicit Pow(ArgT x, Id<T> exponent) :
m_x(std::move(x)),
explicit Pow(ArgT&& x, Id<T> exponent) :
m_x(std::forward<ArgT>(x)),
m_exponent(std::move(exponent))
{
}
@@ -415,23 +422,23 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
ArgT m_x;
StoreValueOrRef<ArgT> m_x;
T m_exponent;
};
template <typename ArgT, typename T = typename ArgT::ValueType>
auto pow(ArgT x, Id<T> exp)
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto pow(ArgT&& x, Id<T> exp)
{
return Pow(std::move(x), std::move(exp));
return Pow<ArgT&&>(std::forward<ArgT>(x), std::move(exp));
}
template <typename ArgT, typename T = typename ArgT::ValueType>
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
struct Log : Evaluable<T, Log<ArgT, T>>
{
using ValueType = T;
explicit Log(ArgT x) :
m_x(std::move(x))
explicit Log(ArgT&& x) :
m_x(std::forward<ArgT>(x))
{
}
@@ -448,7 +455,7 @@ namespace Learner::Autograd::UnivariateStatic
}
private:
ArgT m_x;
StoreValueOrRef<ArgT> m_x;
T value_(T x) const
{
@@ -461,10 +468,10 @@ namespace Learner::Autograd::UnivariateStatic
}
};
template <typename ArgT, typename T = typename ArgT::ValueType>
auto log(ArgT x)
template <typename ArgT, typename T = typename std::remove_reference_t<ArgT>::ValueType>
auto log(ArgT&& x)
{
return Log(std::move(x));
return Log<ArgT&&>(std::forward<ArgT>(x));
}
}