mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Optimize trainer clipped relu propagate
This commit is contained in:
@@ -50,7 +50,73 @@ namespace Eval::NNUE {
|
||||
}
|
||||
|
||||
const auto input = previous_layer_trainer_->propagate(thread_pool, batch);
|
||||
|
||||
batch_size_ = static_cast<IndexType>(batch.size());
|
||||
|
||||
#if defined (USE_SSE2)
|
||||
|
||||
{
|
||||
static_assert(kOutputDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
|
||||
|
||||
const __m128 kZero4 = _mm_set1_ps(+kZero);
|
||||
const __m128 kOne4 = _mm_set1_ps(+kOne);
|
||||
|
||||
for (IndexType b = 0; b < batch.size(); ++b)
|
||||
{
|
||||
const IndexType batch_offset = kOutputDimensions * b;
|
||||
|
||||
for (IndexType i = 0; i < kOutputDimensions; i += 16)
|
||||
{
|
||||
__m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]);
|
||||
__m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]);
|
||||
__m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]);
|
||||
__m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]);
|
||||
|
||||
out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
|
||||
out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
|
||||
out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
|
||||
out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
|
||||
|
||||
_mm_storeu_ps(&output_[i + 0 + batch_offset], out0);
|
||||
_mm_storeu_ps(&output_[i + 4 + batch_offset], out1);
|
||||
_mm_storeu_ps(&output_[i + 8 + batch_offset], out2);
|
||||
_mm_storeu_ps(&output_[i + 12 + batch_offset], out3);
|
||||
|
||||
__m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
|
||||
__m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
|
||||
__m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
|
||||
__m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
|
||||
|
||||
__m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
|
||||
__m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
|
||||
__m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
|
||||
__m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
|
||||
|
||||
minact0 = _mm_min_ps(out0, minact0);
|
||||
minact1 = _mm_min_ps(out1, minact1);
|
||||
minact2 = _mm_min_ps(out2, minact2);
|
||||
minact3 = _mm_min_ps(out3, minact3);
|
||||
|
||||
maxact0 = _mm_max_ps(out0, maxact0);
|
||||
maxact1 = _mm_max_ps(out1, maxact1);
|
||||
maxact2 = _mm_max_ps(out2, maxact2);
|
||||
maxact3 = _mm_max_ps(out3, maxact3);
|
||||
|
||||
_mm_storeu_ps(&min_activations_[i + 0], minact0);
|
||||
_mm_storeu_ps(&min_activations_[i + 4], minact1);
|
||||
_mm_storeu_ps(&min_activations_[i + 8], minact2);
|
||||
_mm_storeu_ps(&min_activations_[i + 12], minact3);
|
||||
|
||||
_mm_storeu_ps(&max_activations_[i + 0], maxact0);
|
||||
_mm_storeu_ps(&max_activations_[i + 4], maxact1);
|
||||
_mm_storeu_ps(&max_activations_[i + 8], maxact2);
|
||||
_mm_storeu_ps(&max_activations_[i + 12], maxact3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
for (IndexType b = 0; b < batch_size_; ++b) {
|
||||
const IndexType batch_offset = kOutputDimensions * b;
|
||||
for (IndexType i = 0; i < kOutputDimensions; ++i) {
|
||||
@@ -61,6 +127,8 @@ namespace Eval::NNUE {
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
return output_.data();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user