diff --git a/src/nnue/trainer/trainer_feature_transformer.h b/src/nnue/trainer/trainer_feature_transformer.h index 8be584e8..c883b594 100644 --- a/src/nnue/trainer/trainer_feature_transformer.h +++ b/src/nnue/trainer/trainer_feature_transformer.h @@ -143,6 +143,119 @@ namespace Eval::NNUE { } } +#if defined (USE_SSE2) + + { + static_assert(kHalfDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time"); + + auto m128_hmin_ps = [](__m128 x3210) { + __m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2)); + __m128 min_x_x_13_20 = _mm_min_ps(x3210, x0032); + // a = [ # , # , min(x[1], x[3]) , min(x[2], x[0]) ] + __m128 min_x_x_20_13 = _mm_shuffle_ps(min_x_x_13_20, min_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1)); + return _mm_cvtss_f32(_mm_min_ps(min_x_x_13_20, min_x_x_20_13)); + }; + + auto m128_hmax_ps = [](__m128 x3210) { + __m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2)); + __m128 max_x_x_13_20 = _mm_max_ps(x3210, x0032); + // a = [ # , # , max(x[1], x[3]) , max(x[2], x[0]) ] + __m128 max_x_x_20_13 = _mm_shuffle_ps(max_x_x_13_20, max_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1)); + return _mm_cvtss_f32(_mm_max_ps(max_x_x_13_20, max_x_x_20_13)); + }; + + const int total_size = batch.size() * kOutputDimensions; + + const __m128 kZero4 = _mm_set1_ps(+kZero); + const __m128 kOne4 = _mm_set1_ps(+kOne); + + __m128 min_pre_activation0 = _mm_set1_ps(min_pre_activation_); + __m128 min_pre_activation1 = _mm_set1_ps(min_pre_activation_); + __m128 max_pre_activation0 = _mm_set1_ps(max_pre_activation_); + __m128 max_pre_activation1 = _mm_set1_ps(max_pre_activation_); + + for (int i = 0; i < total_size; i += 16) + { + __m128 out0 = _mm_loadu_ps(&output_[i + 0]); + __m128 out1 = _mm_loadu_ps(&output_[i + 4]); + __m128 out2 = _mm_loadu_ps(&output_[i + 8]); + __m128 out3 = _mm_loadu_ps(&output_[i + 12]); + + __m128 min01 = _mm_min_ps(out0, out1); + __m128 min23 = _mm_min_ps(out2, out3); + + __m128 max01 = _mm_max_ps(out0, out1); + __m128 max23 = _mm_max_ps(out2, out3); + + min_pre_activation0 = _mm_min_ps(min_pre_activation0, min01); + min_pre_activation1 = _mm_min_ps(min_pre_activation1, min23); + max_pre_activation0 = _mm_max_ps(max_pre_activation0, max01); + max_pre_activation1 = _mm_max_ps(max_pre_activation1, max23); + + out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0)); + out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1)); + out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2)); + out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3)); + + _mm_storeu_ps(&output_[i + 0], out0); + _mm_storeu_ps(&output_[i + 4], out1); + _mm_storeu_ps(&output_[i + 8], out2); + _mm_storeu_ps(&output_[i + 12], out3); + } + + min_pre_activation_ = m128_hmin_ps(_mm_min_ps(min_pre_activation0, min_pre_activation1)); + max_pre_activation_ = m128_hmax_ps(_mm_max_ps(max_pre_activation0, max_pre_activation1)); + + for (IndexType b = 0; b < batch.size(); ++b) + { + const IndexType batch_offset = kOutputDimensions * b; + + for (IndexType half = 0; half < 2; ++half) + { + const IndexType half_offset = batch_offset + half * kHalfDimensions; + for (IndexType i = 0; i < kHalfDimensions; i += 16) + { + const __m128 out0 = _mm_loadu_ps(&output_[i + 0 + half_offset]); + const __m128 out1 = _mm_loadu_ps(&output_[i + 4 + half_offset]); + const __m128 out2 = _mm_loadu_ps(&output_[i + 8 + half_offset]); + const __m128 out3 = _mm_loadu_ps(&output_[i + 12 + half_offset]); + + __m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]); + __m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]); + __m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]); + __m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]); + + __m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]); + __m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]); + __m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]); + __m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]); + + minact0 = _mm_min_ps(out0, minact0); + minact1 = _mm_min_ps(out1, minact1); + minact2 = _mm_min_ps(out2, minact2); + minact3 = _mm_min_ps(out3, minact3); + + maxact0 = _mm_max_ps(out0, maxact0); + maxact1 = _mm_max_ps(out1, maxact1); + maxact2 = _mm_max_ps(out2, maxact2); + maxact3 = _mm_max_ps(out3, maxact3); + + _mm_storeu_ps(&min_activations_[i + 0], minact0); + _mm_storeu_ps(&min_activations_[i + 4], minact1); + _mm_storeu_ps(&min_activations_[i + 8], minact2); + _mm_storeu_ps(&min_activations_[i + 12], minact3); + + _mm_storeu_ps(&max_activations_[i + 0], maxact0); + _mm_storeu_ps(&max_activations_[i + 4], maxact1); + _mm_storeu_ps(&max_activations_[i + 8], maxact2); + _mm_storeu_ps(&max_activations_[i + 12], maxact3); + } + } + } + } + +#else + // clipped ReLU for (IndexType b = 0; b < batch.size(); ++b) { const IndexType batch_offset = kOutputDimensions * b; @@ -157,6 +270,8 @@ namespace Eval::NNUE { } } +#endif + return output_.data(); }