From db1b33d4acfe02d4eb05eac5f810729da1d1ebf4 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Tue, 27 Oct 2020 18:57:47 +0100 Subject: [PATCH] Optimize trainer clipped relu propagate --- src/nnue/trainer/trainer_clipped_relu.h | 68 +++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/nnue/trainer/trainer_clipped_relu.h b/src/nnue/trainer/trainer_clipped_relu.h index dd6fc701..124671ed 100644 --- a/src/nnue/trainer/trainer_clipped_relu.h +++ b/src/nnue/trainer/trainer_clipped_relu.h @@ -50,7 +50,73 @@ namespace Eval::NNUE { } const auto input = previous_layer_trainer_->propagate(thread_pool, batch); + batch_size_ = static_cast<IndexType>(batch.size()); + +#if defined (USE_SSE2) + + { + static_assert(kOutputDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time"); + + const __m128 kZero4 = _mm_set1_ps(+kZero); + const __m128 kOne4 = _mm_set1_ps(+kOne); + + for (IndexType b = 0; b < batch.size(); ++b) + { + const IndexType batch_offset = kOutputDimensions * b; + + for (IndexType i = 0; i < kOutputDimensions; i += 16) + { + __m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]); + __m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]); + __m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]); + __m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]); + + out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0)); + out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1)); + out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2)); + out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3)); + + _mm_storeu_ps(&output_[i + 0 + batch_offset], out0); + _mm_storeu_ps(&output_[i + 4 + batch_offset], out1); + _mm_storeu_ps(&output_[i + 8 + batch_offset], out2); + _mm_storeu_ps(&output_[i + 12 + batch_offset], out3); + + __m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]); + __m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]); + __m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]); + __m128 minact3 = _mm_loadu_ps(&min_activations_[i 
+ 12]); + + __m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]); + __m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]); + __m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]); + __m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]); + + minact0 = _mm_min_ps(out0, minact0); + minact1 = _mm_min_ps(out1, minact1); + minact2 = _mm_min_ps(out2, minact2); + minact3 = _mm_min_ps(out3, minact3); + + maxact0 = _mm_max_ps(out0, maxact0); + maxact1 = _mm_max_ps(out1, maxact1); + maxact2 = _mm_max_ps(out2, maxact2); + maxact3 = _mm_max_ps(out3, maxact3); + + _mm_storeu_ps(&min_activations_[i + 0], minact0); + _mm_storeu_ps(&min_activations_[i + 4], minact1); + _mm_storeu_ps(&min_activations_[i + 8], minact2); + _mm_storeu_ps(&min_activations_[i + 12], minact3); + + _mm_storeu_ps(&max_activations_[i + 0], maxact0); + _mm_storeu_ps(&max_activations_[i + 4], maxact1); + _mm_storeu_ps(&max_activations_[i + 8], maxact2); + _mm_storeu_ps(&max_activations_[i + 12], maxact3); + } + } + } + +#else + for (IndexType b = 0; b < batch_size_; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; ++i) { @@ -61,6 +127,8 @@ namespace Eval::NNUE { } } +#endif + return output_.data(); }