diff --git a/src/nnue/trainer/trainer_feature_transformer.h b/src/nnue/trainer/trainer_feature_transformer.h
index c883b594..77edfbde 100644
--- a/src/nnue/trainer/trainer_feature_transformer.h
+++ b/src/nnue/trainer/trainer_feature_transformer.h
@@ -285,6 +285,55 @@ namespace Eval::NNUE {
     const LearnFloatType local_learning_rate =
         learning_rate * learning_rate_scale_;
 
+#if defined (USE_SSE2)
+
+    {
+      static_assert(kHalfDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
+
+      const __m128 kZero4 = _mm_set1_ps(+kZero);
+      const __m128 kOne4 = _mm_set1_ps(+kOne);
+
+      const IndexType total_size = batch_->size() * kOutputDimensions;
+
+      for (IndexType i = 0; i < total_size; i += 16)
+      {
+        __m128 out0 = _mm_loadu_ps(&output_[i + 0]);
+        __m128 out1 = _mm_loadu_ps(&output_[i + 4]);
+        __m128 out2 = _mm_loadu_ps(&output_[i + 8]);
+        __m128 out3 = _mm_loadu_ps(&output_[i + 12]);
+
+        __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
+        __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
+        __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
+        __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
+
+        __m128 grad0 = _mm_loadu_ps(&gradients[i + 0]);
+        __m128 grad1 = _mm_loadu_ps(&gradients[i + 4]);
+        __m128 grad2 = _mm_loadu_ps(&gradients[i + 8]);
+        __m128 grad3 = _mm_loadu_ps(&gradients[i + 12]);
+
+        grad0 = _mm_andnot_ps(clipped0, grad0);
+        grad1 = _mm_andnot_ps(clipped1, grad1);
+        grad2 = _mm_andnot_ps(clipped2, grad2);
+        grad3 = _mm_andnot_ps(clipped3, grad3);
+
+        _mm_storeu_ps(&gradients_[i + 0], grad0);
+        _mm_storeu_ps(&gradients_[i + 4], grad1);
+        _mm_storeu_ps(&gradients_[i + 8], grad2);
+        _mm_storeu_ps(&gradients_[i + 12], grad3);
+
+        const int clipped_mask =
+            (_mm_movemask_ps(clipped0) << 0)
+          | (_mm_movemask_ps(clipped1) << 4)
+          | (_mm_movemask_ps(clipped2) << 8)
+          | (_mm_movemask_ps(clipped3) << 12);
+
+        num_clipped_ += popcount(clipped_mask);
+      }
+    }
+
+#else
+
     for (IndexType b = 0; b < batch_->size(); ++b) {
       const IndexType batch_offset = kOutputDimensions * b;
       for (IndexType i = 0; i < kOutputDimensions; ++i) {
@@ -294,6 +343,9 @@ namespace Eval::NNUE {
         num_clipped_ += clipped;
       }
     }
+
+#endif
+
     num_total_ += batch_->size() * kOutputDimensions;
 
     // Since the weight matrix updates only the columns corresponding to the features that appeared in the input,
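
Note for reviewers (not part of the patch): the compare-and-mask idiom in the SSE2 branch can look opaque at first sight. The standalone sketch below, with values chosen purely for illustration, shows how an SSE compare yields an all-ones/all-zeros lane mask, how _mm_andnot_ps uses that mask to zero the gradients of clipped activations, and how _mm_movemask_ps plus a population count tallies the clipped lanes. It assumes a C++20 compiler for std::popcount; the patch itself relies on the project's own popcount() helper instead.

// Illustration only: one 4-lane step of the clipping logic used in the patch.
#include <immintrin.h>
#include <bit>
#include <cstdio>

int main() {
    alignas(16) float output[4]   = { -0.5f, 0.25f, 0.75f, 1.5f };  // lanes 0 and 3 lie outside [0, 1]
    alignas(16) float gradient[4] = {  1.0f, 1.0f,  1.0f,  1.0f };
    alignas(16) float masked[4];

    const __m128 zero = _mm_set1_ps(0.0f);
    const __m128 one  = _mm_set1_ps(1.0f);

    const __m128 out = _mm_load_ps(output);
    // All-ones lane where out <= 0 or out >= 1, all-zeros lane otherwise.
    const __m128 clipped = _mm_or_ps(_mm_cmple_ps(out, zero), _mm_cmpge_ps(out, one));
    // andnot keeps the gradient only in the non-clipped lanes.
    _mm_store_ps(masked, _mm_andnot_ps(clipped, _mm_load_ps(gradient)));

    // movemask packs the sign bit of each lane into an integer; popcount counts the clipped lanes.
    const int n_clipped = std::popcount(static_cast<unsigned>(_mm_movemask_ps(clipped)));

    std::printf("masked = %g %g %g %g, clipped lanes = %d\n",
                masked[0], masked[1], masked[2], masked[3], n_clipped);
    // Expected: masked = 0 1 1 0, clipped lanes = 2
}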