From ce7254b5ea3b9b7accb37c8e07fb64eeb5fcfcfa Mon Sep 17 00:00:00 2001 From: mstembera Date: Tue, 24 Jun 2025 15:09:51 -0700 Subject: [PATCH] Optimize find_nnz() using AVX512 About a 1% speedup for ARCH x86-64-avx512 and x86-64-vnni512. Note: This could be optimized further if we wanted to add an ARCH supporting VBMI2 which is even more modern than VNNI. https://en.wikichip.org/wiki/x86/avx512_vbmi2 closes https://github.com/official-stockfish/Stockfish/pull/6139 No functional change --- src/Makefile | 2 +- .../layers/affine_transform_sparse_input.h | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/Makefile b/src/Makefile index 50bb2082..14c3c50c 100644 --- a/src/Makefile +++ b/src/Makefile @@ -701,7 +701,7 @@ endif ifeq ($(avx512),yes) CXXFLAGS += -DUSE_AVX512 ifeq ($(comp),$(filter $(comp),gcc clang mingw icx)) - CXXFLAGS += -mavx512f -mavx512bw + CXXFLAGS += -mavx512f -mavx512bw -mavx512dq -mavx512vl endif endif diff --git a/src/nnue/layers/affine_transform_sparse_input.h b/src/nnue/layers/affine_transform_sparse_input.h index 51f86fd6..e77c98f8 100644 --- a/src/nnue/layers/affine_transform_sparse_input.h +++ b/src/nnue/layers/affine_transform_sparse_input.h @@ -68,9 +68,42 @@ alignas(CacheLineSize) static constexpr struct OffsetIndices { } Lookup; + #if defined(__GNUC__) || defined(__clang__) + #define RESTRICT __restrict__ + #elif defined(_MSC_VER) + #define RESTRICT __restrict + #else + #define RESTRICT + #endif + // Find indices of nonzero numbers in an int32_t array template<const IndexType InputDimensions> -void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_out) { +void find_nnz(const std::int32_t* RESTRICT input, + std::uint16_t* RESTRICT out, + IndexType& count_out) { + + #ifdef USE_AVX512 + constexpr IndexType SimdWidth = 16; // 512 bits / 32 bits + constexpr IndexType NumChunks = InputDimensions / SimdWidth; + const __m512i increment = _mm512_set1_epi32(SimdWidth); + __m512i base = _mm512_set_epi32(15, 14, 13, 12,
11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + + IndexType count = 0; + for (IndexType i = 0; i < NumChunks; ++i) + { + const __m512i inputV = _mm512_load_si512(input + i * SimdWidth); + + // Get a bitmask and gather non zero indices + const __mmask16 nnzMask = _mm512_test_epi32_mask(inputV, inputV); + const __m512i nnzV = _mm512_maskz_compress_epi32(nnzMask, base); + _mm512_mask_cvtepi32_storeu_epi16(out + count, 0xFFFF, nnzV); + count += popcount(nnzMask); + base = _mm512_add_epi32(base, increment); + } + count_out = count; + + #else + using namespace SIMD; constexpr IndexType InputSimdWidth = sizeof(vec_uint_t) / sizeof(std::int32_t); @@ -104,6 +137,7 @@ void find_nnz(const std::int32_t* input, std::uint16_t* out, IndexType& count_ou } } count_out = count; + #endif } #endif