mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 11:36:51 +08:00
Optimize feature transformer clipped relu.
This commit is contained in:
@@ -143,6 +143,119 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined (USE_SSE2)
|
||||
|
||||
{
|
||||
static_assert(kHalfDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
|
||||
|
||||
auto m128_hmin_ps = [](__m128 x3210) {
|
||||
__m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2));
|
||||
__m128 min_x_x_13_20 = _mm_min_ps(x3210, x0032);
|
||||
// a = [ # , # , min(x[1], x[3]) , min(x[2], x[0]) ]
|
||||
__m128 min_x_x_20_13 = _mm_shuffle_ps(min_x_x_13_20, min_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1));
|
||||
return _mm_cvtss_f32(_mm_min_ps(min_x_x_13_20, min_x_x_20_13));
|
||||
};
|
||||
|
||||
auto m128_hmax_ps = [](__m128 x3210) {
|
||||
__m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2));
|
||||
__m128 max_x_x_13_20 = _mm_max_ps(x3210, x0032);
|
||||
// a = [ # , # , max(x[1], x[3]) , max(x[2], x[0]) ]
|
||||
__m128 max_x_x_20_13 = _mm_shuffle_ps(max_x_x_13_20, max_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1));
|
||||
return _mm_cvtss_f32(_mm_max_ps(max_x_x_13_20, max_x_x_20_13));
|
||||
};
|
||||
|
||||
const int total_size = batch.size() * kOutputDimensions;
|
||||
|
||||
const __m128 kZero4 = _mm_set1_ps(+kZero);
|
||||
const __m128 kOne4 = _mm_set1_ps(+kOne);
|
||||
|
||||
__m128 min_pre_activation0 = _mm_set1_ps(min_pre_activation_);
|
||||
__m128 min_pre_activation1 = _mm_set1_ps(min_pre_activation_);
|
||||
__m128 max_pre_activation0 = _mm_set1_ps(max_pre_activation_);
|
||||
__m128 max_pre_activation1 = _mm_set1_ps(max_pre_activation_);
|
||||
|
||||
for (int i = 0; i < total_size; i += 16)
|
||||
{
|
||||
__m128 out0 = _mm_loadu_ps(&output_[i + 0]);
|
||||
__m128 out1 = _mm_loadu_ps(&output_[i + 4]);
|
||||
__m128 out2 = _mm_loadu_ps(&output_[i + 8]);
|
||||
__m128 out3 = _mm_loadu_ps(&output_[i + 12]);
|
||||
|
||||
__m128 min01 = _mm_min_ps(out0, out1);
|
||||
__m128 min23 = _mm_min_ps(out2, out3);
|
||||
|
||||
__m128 max01 = _mm_max_ps(out0, out1);
|
||||
__m128 max23 = _mm_max_ps(out2, out3);
|
||||
|
||||
min_pre_activation0 = _mm_min_ps(min_pre_activation0, min01);
|
||||
min_pre_activation1 = _mm_min_ps(min_pre_activation1, min23);
|
||||
max_pre_activation0 = _mm_max_ps(max_pre_activation0, max01);
|
||||
max_pre_activation1 = _mm_max_ps(max_pre_activation1, max23);
|
||||
|
||||
out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
|
||||
out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
|
||||
out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
|
||||
out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
|
||||
|
||||
_mm_storeu_ps(&output_[i + 0], out0);
|
||||
_mm_storeu_ps(&output_[i + 4], out1);
|
||||
_mm_storeu_ps(&output_[i + 8], out2);
|
||||
_mm_storeu_ps(&output_[i + 12], out3);
|
||||
}
|
||||
|
||||
min_pre_activation_ = m128_hmin_ps(_mm_min_ps(min_pre_activation0, min_pre_activation1));
|
||||
max_pre_activation_ = m128_hmax_ps(_mm_max_ps(max_pre_activation0, max_pre_activation1));
|
||||
|
||||
for (IndexType b = 0; b < batch.size(); ++b)
|
||||
{
|
||||
const IndexType batch_offset = kOutputDimensions * b;
|
||||
|
||||
for (IndexType half = 0; half < 2; ++half)
|
||||
{
|
||||
const IndexType half_offset = batch_offset + half * kHalfDimensions;
|
||||
for (IndexType i = 0; i < kHalfDimensions; i += 16)
|
||||
{
|
||||
const __m128 out0 = _mm_loadu_ps(&output_[i + 0 + half_offset]);
|
||||
const __m128 out1 = _mm_loadu_ps(&output_[i + 4 + half_offset]);
|
||||
const __m128 out2 = _mm_loadu_ps(&output_[i + 8 + half_offset]);
|
||||
const __m128 out3 = _mm_loadu_ps(&output_[i + 12 + half_offset]);
|
||||
|
||||
__m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
|
||||
__m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
|
||||
__m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
|
||||
__m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
|
||||
|
||||
__m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
|
||||
__m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
|
||||
__m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
|
||||
__m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
|
||||
|
||||
minact0 = _mm_min_ps(out0, minact0);
|
||||
minact1 = _mm_min_ps(out1, minact1);
|
||||
minact2 = _mm_min_ps(out2, minact2);
|
||||
minact3 = _mm_min_ps(out3, minact3);
|
||||
|
||||
maxact0 = _mm_max_ps(out0, maxact0);
|
||||
maxact1 = _mm_max_ps(out1, maxact1);
|
||||
maxact2 = _mm_max_ps(out2, maxact2);
|
||||
maxact3 = _mm_max_ps(out3, maxact3);
|
||||
|
||||
_mm_storeu_ps(&min_activations_[i + 0], minact0);
|
||||
_mm_storeu_ps(&min_activations_[i + 4], minact1);
|
||||
_mm_storeu_ps(&min_activations_[i + 8], minact2);
|
||||
_mm_storeu_ps(&min_activations_[i + 12], minact3);
|
||||
|
||||
_mm_storeu_ps(&max_activations_[i + 0], maxact0);
|
||||
_mm_storeu_ps(&max_activations_[i + 4], maxact1);
|
||||
_mm_storeu_ps(&max_activations_[i + 8], maxact2);
|
||||
_mm_storeu_ps(&max_activations_[i + 12], maxact3);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// clipped ReLU
|
||||
for (IndexType b = 0; b < batch.size(); ++b) {
|
||||
const IndexType batch_offset = kOutputDimensions * b;
|
||||
@@ -157,6 +270,8 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
return output_.data();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user