diff --git a/src/nnue/network.cpp b/src/nnue/network.cpp index cba3abc6..e23294e4 100644 --- a/src/nnue/network.cpp +++ b/src/nnue/network.cpp @@ -219,13 +219,13 @@ Network::evaluate(const Position& pos #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) TransformedFeatureType - transformedFeaturesUnaligned[FeatureTransformer::BufferSize + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); #else - alignas(alignment) TransformedFeatureType - transformedFeatures[FeatureTransformer::BufferSize]; + alignas(alignment) + TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; #endif ASSERT_ALIGNED(transformedFeatures, alignment); @@ -290,13 +290,13 @@ Network::trace_evaluate(const Position& #if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN) TransformedFeatureType - transformedFeaturesUnaligned[FeatureTransformer::BufferSize + transformedFeaturesUnaligned[FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)]; auto* transformedFeatures = align_ptr_up(&transformedFeaturesUnaligned[0]); #else - alignas(alignment) TransformedFeatureType - transformedFeatures[FeatureTransformer::BufferSize]; + alignas(alignment) + TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize]; #endif ASSERT_ALIGNED(transformedFeatures, alignment); @@ -452,12 +452,10 @@ bool Network::write_parameters(std::ostream& stream, // Explicit template instantiations -template class Network< - NetworkArchitecture, - FeatureTransformer>; +template class Network, + FeatureTransformer>; -template class Network< - NetworkArchitecture, - FeatureTransformer>; +template class Network, + FeatureTransformer>; } // namespace Stockfish::Eval::NNUE diff --git a/src/nnue/network.h b/src/nnue/network.h index 21df4b0a..cd32c531 100644 --- a/src/nnue/network.h +++ b/src/nnue/network.h @@ -110,13 +110,11 @@ class Network { }; // Definitions of the network types -using SmallFeatureTransformer = - FeatureTransformer; +using SmallFeatureTransformer = FeatureTransformer; using SmallNetworkArchitecture = NetworkArchitecture; -using BigFeatureTransformer = - FeatureTransformer; +using BigFeatureTransformer = FeatureTransformer; using BigNetworkArchitecture = NetworkArchitecture; using NetworkBig = Network; diff --git a/src/nnue/nnue_accumulator.cpp b/src/nnue/nnue_accumulator.cpp index efa8df90..37af7a0f 100644 --- a/src/nnue/nnue_accumulator.cpp +++ b/src/nnue/nnue_accumulator.cpp @@ -48,22 +48,20 @@ namespace Stockfish::Eval::NNUE { namespace { -template AccumulatorState::*accPtr> +template void update_accumulator_incremental( - const FeatureTransformer& featureTransformer, - const Square ksq, - AccumulatorState& target_state, - const AccumulatorState& computed); + const FeatureTransformer& featureTransformer, + const Square ksq, + AccumulatorState& target_state, + const AccumulatorState& computed); -template AccumulatorState::*accPtr> -void update_accumulator_refresh_cache( - const FeatureTransformer& featureTransformer, - const Position& pos, - AccumulatorState& accumulatorState, - AccumulatorCaches::Cache& cache); +template +void update_accumulator_refresh_cache(const FeatureTransformer& featureTransformer, + const Position& pos, + AccumulatorState& accumulatorState, + AccumulatorCaches::Cache& cache); } @@ -86,18 +84,14 @@ void AccumulatorStack::reset(const Position& rootPos, AccumulatorCaches& caches) noexcept { m_current_idx = 1; - update_accumulator_refresh_cache( + update_accumulator_refresh_cache( *networks.big.featureTransformer, rootPos, m_accumulators[0], caches.big); - update_accumulator_refresh_cache( + update_accumulator_refresh_cache( *networks.big.featureTransformer, rootPos, m_accumulators[0], caches.big); - update_accumulator_refresh_cache( + update_accumulator_refresh_cache( *networks.small.featureTransformer, rootPos, m_accumulators[0], caches.small); - update_accumulator_refresh_cache( + update_accumulator_refresh_cache( *networks.small.featureTransformer, rootPos, m_accumulators[0], caches.small); } @@ -112,24 +106,23 @@ void AccumulatorStack::pop() noexcept { m_current_idx--; } -template AccumulatorState::*accPtr> -void AccumulatorStack::evaluate(const Position& pos, - const FeatureTransformer& featureTransformer, - AccumulatorCaches::Cache& cache) noexcept { +template +void AccumulatorStack::evaluate(const Position& pos, + const FeatureTransformer& featureTransformer, + AccumulatorCaches::Cache& cache) noexcept { evaluate_side(pos, featureTransformer, cache); evaluate_side(pos, featureTransformer, cache); } -template AccumulatorState::*accPtr> -void AccumulatorStack::evaluate_side( - const Position& pos, - const FeatureTransformer& featureTransformer, - AccumulatorCaches::Cache& cache) noexcept { +template +void AccumulatorStack::evaluate_side(const Position& pos, + const FeatureTransformer& featureTransformer, + AccumulatorCaches::Cache& cache) noexcept { - const auto last_usable_accum = find_last_usable_accumulator(); + const auto last_usable_accum = find_last_usable_accumulator(); - if ((m_accumulators[last_usable_accum].*accPtr).computed[Perspective]) + if ((m_accumulators[last_usable_accum].template acc()).computed[Perspective]) forward_update_incremental(pos, featureTransformer, last_usable_accum); else @@ -141,12 +134,12 @@ void AccumulatorStack::evaluate_side( // Find the earliest usable accumulator, this can either be a computed accumulator or the accumulator // state just before a change that requires full refresh. -template AccumulatorState::*accPtr> +template std::size_t AccumulatorStack::find_last_usable_accumulator() const noexcept { for (std::size_t curr_idx = m_current_idx - 1; curr_idx > 0; curr_idx--) { - if ((m_accumulators[curr_idx].*accPtr).computed[Perspective]) + if ((m_accumulators[curr_idx].template acc()).computed[Perspective]) return curr_idx; if (FeatureSet::requires_refresh(m_accumulators[curr_idx].dirtyPiece, Perspective)) @@ -156,14 +149,14 @@ std::size_t AccumulatorStack::find_last_usable_accumulator() const noexcept { return 0; } -template AccumulatorState::*accPtr> +template void AccumulatorStack::forward_update_incremental( - const Position& pos, - const FeatureTransformer& featureTransformer, - const std::size_t begin) noexcept { + const Position& pos, + const FeatureTransformer& featureTransformer, + const std::size_t begin) noexcept { assert(begin < m_accumulators.size()); - assert((m_accumulators[begin].*accPtr).computed[Perspective]); + assert((m_accumulators[begin].acc()).computed[Perspective]); const Square ksq = pos.square(Perspective); @@ -171,18 +164,18 @@ void AccumulatorStack::forward_update_incremental( update_accumulator_incremental(featureTransformer, ksq, m_accumulators[next], m_accumulators[next - 1]); - assert((latest().*accPtr).computed[Perspective]); + assert((latest().acc()).computed[Perspective]); } -template AccumulatorState::*accPtr> +template void AccumulatorStack::backward_update_incremental( - const Position& pos, - const FeatureTransformer& featureTransformer, - const std::size_t end) noexcept { + const Position& pos, + const FeatureTransformer& featureTransformer, + const std::size_t end) noexcept { assert(end < m_accumulators.size()); assert(end < m_current_idx); - assert((latest().*accPtr).computed[Perspective]); + assert((latest().acc()).computed[Perspective]); const Square ksq = pos.square(Perspective); @@ -190,21 +183,17 @@ void AccumulatorStack::backward_update_incremental( update_accumulator_incremental( featureTransformer, ksq, m_accumulators[next], m_accumulators[next + 1]); - assert((m_accumulators[end].*accPtr).computed[Perspective]); + assert((m_accumulators[end].acc()).computed[Perspective]); } // Explicit template instantiations -template void -AccumulatorStack::evaluate( - const Position& pos, - const FeatureTransformer& - featureTransformer, +template void AccumulatorStack::evaluate( + const Position& pos, + const FeatureTransformer& featureTransformer, AccumulatorCaches::Cache& cache) noexcept; -template void -AccumulatorStack::evaluate( - const Position& pos, - const FeatureTransformer& - featureTransformer, +template void AccumulatorStack::evaluate( + const Position& pos, + const FeatureTransformer& featureTransformer, AccumulatorCaches::Cache& cache) noexcept; @@ -227,15 +216,15 @@ void fused_row_reduce(const ElementType* in, ElementType* out, const Ts* const.. vecIn[i], reinterpret_cast(rows)[i]...); } -template AccumulatorState::*accPtr> +template struct AccumulatorUpdateContext { - const FeatureTransformer& featureTransformer; - const AccumulatorState& from; - AccumulatorState& to; + const FeatureTransformer& featureTransformer; + const AccumulatorState& from; + AccumulatorState& to; - AccumulatorUpdateContext(const FeatureTransformer& ft, - const AccumulatorState& accF, - AccumulatorState& accT) noexcept : + AccumulatorUpdateContext(const FeatureTransformer& ft, + const AccumulatorState& accF, + AccumulatorState& accT) noexcept : featureTransformer{ft}, from{accF}, to{accT} {} @@ -252,41 +241,37 @@ struct AccumulatorUpdateContext { return &featureTransformer.psqtWeights[index * PSQTBuckets]; }; - fused_row_reduce((from.*accPtr).accumulation[Perspective], - (to.*accPtr).accumulation[Perspective], - to_weight_vector(indices)...); + fused_row_reduce( + (from.acc()).accumulation[Perspective], + (to.acc()).accumulation[Perspective], to_weight_vector(indices)...); fused_row_reduce( - (from.*accPtr).psqtAccumulation[Perspective], (to.*accPtr).psqtAccumulation[Perspective], - to_psqt_weight_vector(indices)...); + (from.acc()).psqtAccumulation[Perspective], + (to.acc()).psqtAccumulation[Perspective], to_psqt_weight_vector(indices)...); } }; -template AccumulatorState::*accPtr> -auto make_accumulator_update_context( - const FeatureTransformer& featureTransformer, - const AccumulatorState& accumulatorFrom, - AccumulatorState& accumulatorTo) noexcept { - return AccumulatorUpdateContext{ - featureTransformer, accumulatorFrom, accumulatorTo}; +template +auto make_accumulator_update_context(const FeatureTransformer& featureTransformer, + const AccumulatorState& accumulatorFrom, + AccumulatorState& accumulatorTo) noexcept { + return AccumulatorUpdateContext{featureTransformer, accumulatorFrom, + accumulatorTo}; } -template AccumulatorState::*accPtr> +template void update_accumulator_incremental( - const FeatureTransformer& featureTransformer, - const Square ksq, - AccumulatorState& target_state, - const AccumulatorState& computed) { + const FeatureTransformer& featureTransformer, + const Square ksq, + AccumulatorState& target_state, + const AccumulatorState& computed) { [[maybe_unused]] constexpr bool Forward = Direction == FORWARD; [[maybe_unused]] constexpr bool Backward = Direction == BACKWARD; assert(Forward != Backward); - assert((computed.*accPtr).computed[Perspective]); - assert(!(target_state.*accPtr).computed[Perspective]); + assert((computed.acc()).computed[Perspective]); + assert(!(target_state.acc()).computed[Perspective]); // The size must be enough to contain the largest possible update. // That might depend on the feature set and generally relies on the @@ -340,15 +325,14 @@ void update_accumulator_incremental( removed[1]); } - (target_state.*accPtr).computed[Perspective] = true; + (target_state.acc()).computed[Perspective] = true; } -template AccumulatorState::*accPtr> -void update_accumulator_refresh_cache( - const FeatureTransformer& featureTransformer, - const Position& pos, - AccumulatorState& accumulatorState, - AccumulatorCaches::Cache& cache) { +template +void update_accumulator_refresh_cache(const FeatureTransformer& featureTransformer, + const Position& pos, + AccumulatorState& accumulatorState, + AccumulatorCaches::Cache& cache) { using Tiling [[maybe_unused]] = SIMDTiling; const Square ksq = pos.square(Perspective); @@ -378,7 +362,7 @@ void update_accumulator_refresh_cache( } } - auto& accumulator = accumulatorState.*accPtr; + auto& accumulator = accumulatorState.acc(); accumulator.computed[Perspective] = true; #ifdef VECTOR diff --git a/src/nnue/nnue_accumulator.h b/src/nnue/nnue_accumulator.h index 362ea83e..d83a5a44 100644 --- a/src/nnue/nnue_accumulator.h +++ b/src/nnue/nnue_accumulator.h @@ -46,10 +46,7 @@ struct Networks; template struct alignas(CacheLineSize) Accumulator; -struct AccumulatorState; - -template AccumulatorState::*accPtr> +template class FeatureTransformer; // Class that holds the result of affine transformation of input features @@ -121,6 +118,30 @@ struct AccumulatorState { Accumulator accumulatorSmall; DirtyPiece dirtyPiece; + template + auto& acc() noexcept { + static_assert(Size == TransformedFeatureDimensionsBig + || Size == TransformedFeatureDimensionsSmall, + "Invalid size for accumulator"); + + if constexpr (Size == TransformedFeatureDimensionsBig) + return accumulatorBig; + else if constexpr (Size == TransformedFeatureDimensionsSmall) + return accumulatorSmall; + } + + template + const auto& acc() const noexcept { + static_assert(Size == TransformedFeatureDimensionsBig + || Size == TransformedFeatureDimensionsSmall, + "Invalid size for accumulator"); + + if constexpr (Size == TransformedFeatureDimensionsBig) + return accumulatorBig; + else if constexpr (Size == TransformedFeatureDimensionsSmall) + return accumulatorSmall; + } + void reset(const DirtyPiece& dp) noexcept; }; @@ -138,41 +159,31 @@ class AccumulatorStack { void push(const DirtyPiece& dirtyPiece) noexcept; void pop() noexcept; - template AccumulatorState::*accPtr> - void evaluate(const Position& pos, - const FeatureTransformer& featureTransformer, - AccumulatorCaches::Cache& cache) noexcept; + template + void evaluate(const Position& pos, + const FeatureTransformer& featureTransformer, + AccumulatorCaches::Cache& cache) noexcept; private: [[nodiscard]] AccumulatorState& mut_latest() noexcept; - template AccumulatorState::*accPtr> - void evaluate_side(const Position& pos, - const FeatureTransformer& featureTransformer, - AccumulatorCaches::Cache& cache) noexcept; + template + void evaluate_side(const Position& pos, + const FeatureTransformer& featureTransformer, + AccumulatorCaches::Cache& cache) noexcept; - template AccumulatorState::*accPtr> + template [[nodiscard]] std::size_t find_last_usable_accumulator() const noexcept; - template AccumulatorState::*accPtr> - void - forward_update_incremental(const Position& pos, - const FeatureTransformer& featureTransformer, - const std::size_t begin) noexcept; + template + void forward_update_incremental(const Position& pos, + const FeatureTransformer& featureTransformer, + const std::size_t begin) noexcept; - template AccumulatorState::*accPtr> - void - backward_update_incremental(const Position& pos, - const FeatureTransformer& featureTransformer, - const std::size_t end) noexcept; + template + void backward_update_incremental(const Position& pos, + const FeatureTransformer& featureTransformer, + const std::size_t end) noexcept; std::vector m_accumulators; std::size_t m_current_idx; diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 9dee29c1..d2abd40f 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -294,8 +294,7 @@ class SIMDTiling { // Input feature converter -template AccumulatorState::*accPtr> +template class FeatureTransformer { // Number of output dimensions for one side @@ -400,12 +399,12 @@ class FeatureTransformer { const auto& accumulatorState = accumulatorStack.latest(); const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()}; - const auto& psqtAccumulation = (accumulatorState.*accPtr).psqtAccumulation; + const auto& psqtAccumulation = (accumulatorState.acc()).psqtAccumulation; const auto psqt = (psqtAccumulation[perspectives[0]][bucket] - psqtAccumulation[perspectives[1]][bucket]) / 2; - const auto& accumulation = (accumulatorState.*accPtr).accumulation; + const auto& accumulation = (accumulatorState.acc()).accumulation; for (IndexType p = 0; p < 2; ++p) {