diff --git a/src/learn/autograd.h b/src/learn/autograd.h index afbcc41b..714f741a 100644 --- a/src/learn/autograd.h +++ b/src/learn/autograd.h @@ -69,6 +69,13 @@ namespace Learner::Autograd::UnivariateStatic template using Id = typename Identity::type; + template + using StoreValueOrRef = std::conditional_t< + std::is_rvalue_reference_v, + std::remove_reference_t, + const std::remove_reference_t& + >; + template struct Evaluable { @@ -179,14 +186,14 @@ namespace Learner::Autograd::UnivariateStatic T m_x; }; - template + template ::ValueType> struct Sum : Evaluable> { using ValueType = T; - Sum(LhsT lhs, RhsT rhs) : - m_lhs(std::move(lhs)), - m_rhs(std::move(rhs)) + Sum(LhsT&& lhs, RhsT&& rhs) : + m_lhs(std::forward(lhs)), + m_rhs(std::forward(rhs)) { } @@ -203,36 +210,36 @@ namespace Learner::Autograd::UnivariateStatic } private: - LhsT m_lhs; - RhsT m_rhs; + StoreValueOrRef m_lhs; + StoreValueOrRef m_rhs; }; - template - auto operator+(LhsT lhs, RhsT rhs) + template ::ValueType> + auto operator+(LhsT&& lhs, RhsT&& rhs) { - return Sum(std::move(lhs), std::move(rhs)); + return Sum(std::forward(lhs), std::forward(rhs)); } - template - auto operator+(LhsT lhs, Id rhs) + template ::ValueType> + auto operator+(LhsT&& lhs, Id rhs) { - return Sum(std::move(lhs), Constant(rhs)); + return Sum&&>(std::forward(lhs), Constant(rhs)); } - template - auto operator+(Id lhs, RhsT rhs) + template ::ValueType> + auto operator+(Id lhs, RhsT&& rhs) { - return Sum(Constant(lhs), std::move(rhs)); + return Sum&&, RhsT&&>(Constant(lhs), std::forward(rhs)); } - template + template ::ValueType> struct Difference : Evaluable> { using ValueType = T; - Difference(LhsT lhs, RhsT rhs) : - m_lhs(std::move(lhs)), - m_rhs(std::move(rhs)) + Difference(LhsT&& lhs, RhsT&& rhs) : + m_lhs(std::forward(lhs)), + m_rhs(std::forward(rhs)) { } @@ -249,36 +256,36 @@ namespace Learner::Autograd::UnivariateStatic } private: - LhsT m_lhs; - RhsT m_rhs; + StoreValueOrRef m_lhs; + StoreValueOrRef m_rhs; }; - template - auto operator-(LhsT lhs, RhsT rhs) + template ::ValueType> + auto operator-(LhsT&& lhs, RhsT&& rhs) { - return Difference(std::move(lhs), std::move(rhs)); + return Difference(std::forward(lhs), std::forward(rhs)); } - template - auto operator-(LhsT lhs, Id rhs) + template ::ValueType> + auto operator-(LhsT&& lhs, Id rhs) { - return Difference(std::move(lhs), Constant(rhs)); + return Difference&&>(std::forward(lhs), Constant(rhs)); } - template - auto operator-(Id lhs, RhsT rhs) + template ::ValueType> + auto operator-(Id lhs, RhsT&& rhs) { - return Difference(Constant(lhs), std::move(rhs)); + return Difference&&, RhsT&&>(Constant(lhs), std::forward(rhs)); } - template + template ::ValueType> struct Product : Evaluable> { using ValueType = T; - Product(LhsT lhs, RhsT rhs) : - m_lhs(std::move(lhs)), - m_rhs(std::move(rhs)) + Product(LhsT&& lhs, RhsT&& rhs) : + m_lhs(std::forward(lhs)), + m_rhs(std::forward(rhs)) { } @@ -295,35 +302,35 @@ namespace Learner::Autograd::UnivariateStatic } private: - LhsT m_lhs; - RhsT m_rhs; + StoreValueOrRef m_lhs; + StoreValueOrRef m_rhs; }; - template - auto operator*(LhsT lhs, RhsT rhs) + template ::ValueType> + auto operator*(LhsT&& lhs, RhsT&& rhs) { - return Product(std::move(lhs), std::move(rhs)); + return Product(std::forward(lhs), std::forward(rhs)); } - template - auto operator*(LhsT lhs, Id rhs) + template ::ValueType> + auto operator*(LhsT&& lhs, Id rhs) { - return Product(std::move(lhs), Constant(rhs)); + return Product&&>(std::forward(lhs), Constant(rhs)); } - template - auto operator*(Id lhs, RhsT rhs) + template ::ValueType> + auto operator*(Id lhs, RhsT&& rhs) { - return Product(Constant(lhs), std::move(rhs)); + return Product&&, RhsT&&>(Constant(lhs), std::forward(rhs)); } - template + template ::ValueType> struct Negation : Evaluable> { using ValueType = T; - explicit Negation(ArgT x) : - m_x(std::move(x)) + explicit Negation(ArgT&& x) : + m_x(std::forward(x)) { } @@ -340,22 +347,22 @@ namespace Learner::Autograd::UnivariateStatic } private: - ArgT m_x; + StoreValueOrRef m_x; }; - template - auto operator-(ArgT x) + template ::ValueType> + auto operator-(ArgT&& x) { - return Negation(std::move(x)); + return Negation(std::forward(x)); } - template + template ::ValueType> struct Sigmoid : Evaluable> { using ValueType = T; - explicit Sigmoid(ArgT x) : - m_x(std::move(x)) + explicit Sigmoid(ArgT&& x) : + m_x(std::forward(x)) { } @@ -372,7 +379,7 @@ namespace Learner::Autograd::UnivariateStatic } private: - ArgT m_x; + StoreValueOrRef m_x; T value_(T x) const { @@ -385,19 +392,19 @@ namespace Learner::Autograd::UnivariateStatic } }; - template - auto sigmoid(ArgT x) + template ::ValueType> + auto sigmoid(ArgT&& x) { - return Sigmoid(std::move(x)); + return Sigmoid(std::forward(x)); } - template + template ::ValueType> struct Pow : Evaluable> { using ValueType = T; - explicit Pow(ArgT x, Id exponent) : - m_x(std::move(x)), + explicit Pow(ArgT&& x, Id exponent) : + m_x(std::forward(x)), m_exponent(std::move(exponent)) { } @@ -415,23 +422,23 @@ namespace Learner::Autograd::UnivariateStatic } private: - ArgT m_x; + StoreValueOrRef m_x; T m_exponent; }; - template - auto pow(ArgT x, Id exp) + template ::ValueType> + auto pow(ArgT&& x, Id exp) { - return Pow(std::move(x), std::move(exp)); + return Pow(std::forward(x), std::move(exp)); } - template + template ::ValueType> struct Log : Evaluable> { using ValueType = T; - explicit Log(ArgT x) : - m_x(std::move(x)) + explicit Log(ArgT&& x) : + m_x(std::forward(x)) { } @@ -448,7 +455,7 @@ namespace Learner::Autograd::UnivariateStatic } private: - ArgT m_x; + StoreValueOrRef m_x; T value_(T x) const { @@ -461,10 +468,10 @@ namespace Learner::Autograd::UnivariateStatic } }; - template - auto log(ArgT x) + template ::ValueType> + auto log(ArgT&& x) { - return Log(std::move(x)); + return Log(std::forward(x)); } }