From c6a1e7fd4232ec151206fab16cb7daa23bfd7137 Mon Sep 17 00:00:00 2001 From: cj5716 <125858804+cj5716@users.noreply.github.com> Date: Sun, 19 May 2024 13:15:42 +0800 Subject: [PATCH] Optimise pairwise multiplication This speedup was first inspired by a comment by @AndyGrant on my recent PR "If mullo_epi16 would preserve the signedness, then this could be used to remove 50% of the max operations during the halfkp-pairwise mat-mul relu deal." That got me thinking, because although mullo_epi16 did not preserve the signedness, mulhi_epi16 did, and so we could shift left and then use mulhi_epi16, instead of shifting right after the mullo. However, due to some issues with shifting into the sign bit, the FT weights and biases had to be multiplied by 2 for the optimisation to work. Speedup on "Arch=x86-64-bmi2 COMP=clang", courtesy of @Torom Result of 50 runs base (...es/stockfish) = 962946 +/- 1202 test (...ise-max-less) = 979696 +/- 1084 diff = +16750 +/- 1794 speedup = +0.0174 P(speedup > 0) = 1.0000 CPU: 4 x Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz Hyperthreading: on Also a speedup on "COMP=gcc", courtesy of Torom once again Result of 50 runs base (...tockfish_gcc) = 966033 +/- 1574 test (...max-less_gcc) = 983319 +/- 1513 diff = +17286 +/- 2515 speedup = +0.0179 P(speedup > 0) = 1.0000 CPU: 4 x Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz Hyperthreading: on Passed STC: LLR: 2.96 (-2.94,2.94) <0.00,2.00> Total: 67712 W: 17715 L: 17358 D: 32639 Ptnml(0-2): 225, 7472, 18140, 7759, 260 https://tests.stockfishchess.org/tests/view/664c1d75830eb9f886616906 closes https://github.com/official-stockfish/Stockfish/pull/5282 No functional change --- src/nnue/nnue_feature_transformer.h | 80 +++++++++++++++++++---------- 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 7b7aada3..483b84a8 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -55,14 +55,14 @@ using psqt_vec_t = __m256i; #define vec_store(a, b) _mm512_store_si512(a, b) #define vec_add_16(a, b) _mm512_add_epi16(a, b) #define vec_sub_16(a, b) _mm512_sub_epi16(a, b) - #define vec_mul_16(a, b) _mm512_mullo_epi16(a, b) + #define vec_mulhi_16(a, b) _mm512_mulhi_epi16(a, b) #define vec_zero() _mm512_setzero_epi32() #define vec_set_16(a) _mm512_set1_epi16(a) #define vec_max_16(a, b) _mm512_max_epi16(a, b) #define vec_min_16(a, b) _mm512_min_epi16(a, b) + #define vec_slli_16(a, b) _mm512_slli_epi16(a, b) // Inverse permuted at load time - #define vec_msb_pack_16(a, b) \ - _mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7)) + #define vec_packus_16(a, b) _mm512_packus_epi16(a, b) #define vec_load_psqt(a) _mm256_load_si256(a) #define vec_store_psqt(a, b) _mm256_store_si256(a, b) #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b) @@ -78,14 +78,14 @@ using psqt_vec_t = __m256i; #define vec_store(a, b) _mm256_store_si256(a, b) #define vec_add_16(a, b) _mm256_add_epi16(a, b) #define vec_sub_16(a, b) _mm256_sub_epi16(a, b) - #define vec_mul_16(a, b) _mm256_mullo_epi16(a, b) + #define vec_mulhi_16(a, b) _mm256_mulhi_epi16(a, b) #define vec_zero() _mm256_setzero_si256() #define vec_set_16(a) _mm256_set1_epi16(a) #define vec_max_16(a, b) _mm256_max_epi16(a, b) #define vec_min_16(a, b) _mm256_min_epi16(a, b) + #define vec_slli_16(a, b) _mm256_slli_epi16(a, b) // Inverse permuted at load time - #define vec_msb_pack_16(a, b) \ - _mm256_packs_epi16(_mm256_srli_epi16(a, 7), _mm256_srli_epi16(b, 7)) + #define vec_packus_16(a, b) _mm256_packus_epi16(a, b) #define vec_load_psqt(a) _mm256_load_si256(a) #define vec_store_psqt(a, b) _mm256_store_si256(a, b) #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b) @@ -101,12 +101,13 @@ using psqt_vec_t = __m128i; #define vec_store(a, b) *(a) = (b) #define vec_add_16(a, b) _mm_add_epi16(a, b) #define vec_sub_16(a, b) _mm_sub_epi16(a, b) - #define vec_mul_16(a, b) _mm_mullo_epi16(a, b) + #define vec_mulhi_16(a, b) _mm_mulhi_epi16(a, b) #define vec_zero() _mm_setzero_si128() #define vec_set_16(a) _mm_set1_epi16(a) #define vec_max_16(a, b) _mm_max_epi16(a, b) #define vec_min_16(a, b) _mm_min_epi16(a, b) - #define vec_msb_pack_16(a, b) _mm_packs_epi16(_mm_srli_epi16(a, 7), _mm_srli_epi16(b, 7)) + #define vec_slli_16(a, b) _mm_slli_epi16(a, b) + #define vec_packus_16(a, b) _mm_packus_epi16(a, b) #define vec_load_psqt(a) (*(a)) #define vec_store_psqt(a, b) *(a) = (b) #define vec_add_psqt_32(a, b) _mm_add_epi32(a, b) @@ -122,18 +123,14 @@ using psqt_vec_t = int32x4_t; #define vec_store(a, b) *(a) = (b) #define vec_add_16(a, b) vaddq_s16(a, b) #define vec_sub_16(a, b) vsubq_s16(a, b) - #define vec_mul_16(a, b) vmulq_s16(a, b) + #define vec_mulhi_16(a, b) vqdmulhq_s16(a, b) #define vec_zero() \ vec_t { 0 } #define vec_set_16(a) vdupq_n_s16(a) #define vec_max_16(a, b) vmaxq_s16(a, b) #define vec_min_16(a, b) vminq_s16(a, b) -inline vec_t vec_msb_pack_16(vec_t a, vec_t b) { - const int8x8_t shifta = vshrn_n_s16(a, 7); - const int8x8_t shiftb = vshrn_n_s16(b, 7); - const int8x16_t compacted = vcombine_s8(shifta, shiftb); - return *reinterpret_cast(&compacted); -} + #define vec_slli_16(a, b) vshlq_s16(a, vec_set_16(b)) + #define vec_packus_16(a, b) reinterpret_cast(vcombine_u8(vqmovun_s16(a), vqmovun_s16(b))) #define vec_load_psqt(a) (*(a)) #define vec_store_psqt(a, b) *(a) = (b) #define vec_add_psqt_32(a, b) vaddq_s32(a, b) @@ -281,6 +278,19 @@ class FeatureTransformer { #endif } + inline void scale_weights(bool read) const { + for (IndexType j = 0; j < InputDimensions; ++j) + { + WeightType* w = const_cast(&weights[j * HalfDimensions]); + for (IndexType i = 0; i < HalfDimensions; ++i) + w[i] = read ? w[i] * 2 : w[i] / 2; + } + + BiasType* b = const_cast(biases); + for (IndexType i = 0; i < HalfDimensions; ++i) + b[i] = read ? b[i] * 2 : b[i] / 2; + } + // Read network parameters bool read_parameters(std::istream& stream) { @@ -289,6 +299,7 @@ class FeatureTransformer { read_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); permute_weights(inverse_order_packs); + scale_weights(true); return !stream.fail(); } @@ -296,12 +307,14 @@ class FeatureTransformer { bool write_parameters(std::ostream& stream) const { permute_weights(order_packs); + scale_weights(false); write_leb_128(stream, biases, HalfDimensions); write_leb_128(stream, weights, HalfDimensions * InputDimensions); write_leb_128(stream, psqtWeights, PSQTBuckets * InputDimensions); permute_weights(inverse_order_packs); + scale_weights(true); return !stream.fail(); } @@ -332,7 +345,7 @@ class FeatureTransformer { constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize; const vec_t Zero = vec_zero(); - const vec_t One = vec_set_16(127); + const vec_t One = vec_set_16(127 * 2); const vec_t* in0 = reinterpret_cast(&(accumulation[perspectives[p]][0])); const vec_t* in1 = @@ -341,15 +354,30 @@ class FeatureTransformer { for (IndexType j = 0; j < NumOutputChunks; ++j) { - const vec_t sum0a = vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero); - const vec_t sum0b = vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero); - const vec_t sum1a = vec_max_16(vec_min_16(in1[j * 2 + 0], One), Zero); - const vec_t sum1b = vec_max_16(vec_min_16(in1[j * 2 + 1], One), Zero); + // What we want to do is multiply inputs in a pairwise manner (after clipping), and then shift right by 9. + // Instead, we shift left by 7, and use mulhi, stripping the bottom 16 bits, effectively shifting right by 16, + // resulting in a net shift of 9 bits. We use mulhi because it maintains the sign of the multiplication (unlike mullo), + // allowing us to make use of packus to clip 2 of the inputs, resulting in a save of 2 "vec_max_16" calls. + // A special case is when we use NEON, where we shift left by 6 instead, because the instruction "vqdmulhq_s16" + // also doubles the return value after the multiplication, adding an extra shift to the left by 1, so we + // compensate by shifting less before the multiplication. - const vec_t pa = vec_mul_16(sum0a, sum1a); - const vec_t pb = vec_mul_16(sum0b, sum1b); + #if defined(USE_SSE2) + constexpr int shift = 7; + #else + constexpr int shift = 6; + #endif + const vec_t sum0a = + vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift); + const vec_t sum0b = + vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero), shift); + const vec_t sum1a = vec_min_16(in1[j * 2 + 0], One); + const vec_t sum1b = vec_min_16(in1[j * 2 + 1], One); - out[j] = vec_msb_pack_16(pa, pb); + const vec_t pa = vec_mulhi_16(sum0a, sum1a); + const vec_t pb = vec_mulhi_16(sum0b, sum1b); + + out[j] = vec_packus_16(pa, pb); } #else @@ -359,9 +387,9 @@ class FeatureTransformer { BiasType sum0 = accumulation[static_cast(perspectives[p])][j + 0]; BiasType sum1 = accumulation[static_cast(perspectives[p])][j + HalfDimensions / 2]; - sum0 = std::clamp(sum0, 0, 127); - sum1 = std::clamp(sum1, 0, 127); - output[offset + j] = static_cast(unsigned(sum0 * sum1) / 128); + sum0 = std::clamp(sum0, 0, 127 * 2); + sum1 = std::clamp(sum1, 0, 127 * 2); + output[offset + j] = static_cast(unsigned(sum0 * sum1) / 512); } #endif