Add AVX512 support.

bench: 3909820
This commit is contained in:
mstembera
2020-07-18 19:21:46 -07:00
committed by nodchip
parent 7a13d4ed60
commit 961a4dad5c
2 changed files with 70 additions and 5 deletions

View File

@@ -82,7 +82,11 @@ class AffineTransform {
const auto input = previous_layer_.Propagate(
transformed_features, buffer + kSelfBufferSize);
const auto output = reinterpret_cast<OutputType*>(buffer);
#if defined(USE_AVX2)
#if defined(USE_AVX512)
constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2);
const __m512i kOnes = _mm512_set1_epi16(1);
const auto input_vector = reinterpret_cast<const __m512i*>(input);
#elif defined(USE_AVX2)
constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth;
const __m256i kOnes = _mm256_set1_epi16(1);
const auto input_vector = reinterpret_cast<const __m256i*>(input);
@@ -96,8 +100,43 @@ class AffineTransform {
#endif
for (IndexType i = 0; i < kOutputDimensions; ++i) {
const IndexType offset = i * kPaddedInputDimensions;
#if defined(USE_AVX2)
__m256i sum = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, biases_[i]);
#if defined(USE_AVX512)
__m512i sum = _mm512_setzero_si512();
const auto row = reinterpret_cast<const __m512i*>(&weights_[offset]);
for (IndexType j = 0; j < kNumChunks; ++j) {
#if defined(__MINGW32__) || defined(__MINGW64__)
__m512i product = _mm512_maddubs_epi16(_mm512_loadu_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
#else
__m512i product = _mm512_maddubs_epi16(_mm512_load_si512(&input_vector[j]), _mm512_load_si512(&row[j]));
#endif
product = _mm512_madd_epi16(product, kOnes);
sum = _mm512_add_epi32(sum, product);
}
output[i] = _mm512_reduce_add_epi32(sum) + biases_[i];
// Note: Changing kMaxSimdWidth from 32 to 64 breaks loading existing networks.
// As a result kPaddedInputDimensions may not be an even multiple of 64(512bit)
// and we have to do one more 256bit chunk.
if (kPaddedInputDimensions != kNumChunks * kSimdWidth * 2)
{
const auto iv_256 = reinterpret_cast<const __m256i*>(input);
const auto row_256 = reinterpret_cast<const __m256i*>(&weights_[offset]);
int j = kNumChunks * 2;
#if defined(__MINGW32__) || defined(__MINGW64__) // See HACK comment below in AVX2.
__m256i sum256 = _mm256_maddubs_epi16(_mm256_loadu_si256(&iv_256[j]), _mm256_load_si256(&row_256[j]));
#else
__m256i sum256 = _mm256_maddubs_epi16(_mm256_load_si256(&iv_256[j]), _mm256_load_si256(&row_256[j]));
#endif
sum256 = _mm256_madd_epi16(sum256, _mm256_set1_epi16(1));
sum256 = _mm256_hadd_epi32(sum256, sum256);
sum256 = _mm256_hadd_epi32(sum256, sum256);
const __m128i lo = _mm256_extracti128_si256(sum256, 0);
const __m128i hi = _mm256_extracti128_si256(sum256, 1);
output[i] += _mm_cvtsi128_si32(lo) + _mm_cvtsi128_si32(hi);
}
#elif defined(USE_AVX2)
__m256i sum = _mm256_setzero_si256();
const auto row = reinterpret_cast<const __m256i*>(&weights_[offset]);
for (IndexType j = 0; j < kNumChunks; ++j) {
__m256i product = _mm256_maddubs_epi16(
@@ -117,7 +156,7 @@ class AffineTransform {
sum = _mm256_hadd_epi32(sum, sum);
const __m128i lo = _mm256_extracti128_si256(sum, 0);
const __m128i hi = _mm256_extracti128_si256(sum, 1);
output[i] = _mm_cvtsi128_si32(lo) + _mm_cvtsi128_si32(hi);
output[i] = _mm_cvtsi128_si32(lo) + _mm_cvtsi128_si32(hi) + biases_[i];
#elif defined(USE_SSSE3)
__m128i sum = _mm_cvtsi32_si128(biases_[i]);
const auto row = reinterpret_cast<const __m128i*>(&weights_[offset]);