mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-25 03:26:24 +08:00
Optimize trainer clipped relu backpropagate.
This commit is contained in:
@@ -69,6 +69,55 @@ namespace Eval::NNUE {
|
||||
const LearnFloatType* gradients,
|
||||
LearnFloatType learning_rate) {
|
||||
|
||||
#if defined (USE_SSE2)
|
||||
|
||||
{
|
||||
static_assert(kOutputDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
|
||||
|
||||
const __m128 kZero4 = _mm_set1_ps(+kZero);
|
||||
const __m128 kOne4 = _mm_set1_ps(+kOne);
|
||||
|
||||
const IndexType total_size = batch_size_ * kOutputDimensions;
|
||||
|
||||
for (IndexType i = 0; i < total_size; i += 16)
|
||||
{
|
||||
__m128 out0 = _mm_loadu_ps(&output_[i + 0]);
|
||||
__m128 out1 = _mm_loadu_ps(&output_[i + 4]);
|
||||
__m128 out2 = _mm_loadu_ps(&output_[i + 8]);
|
||||
__m128 out3 = _mm_loadu_ps(&output_[i + 12]);
|
||||
|
||||
__m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
|
||||
__m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
|
||||
__m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
|
||||
__m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
|
||||
|
||||
__m128 grad0 = _mm_loadu_ps(&gradients[i + 0]);
|
||||
__m128 grad1 = _mm_loadu_ps(&gradients[i + 4]);
|
||||
__m128 grad2 = _mm_loadu_ps(&gradients[i + 8]);
|
||||
__m128 grad3 = _mm_loadu_ps(&gradients[i + 12]);
|
||||
|
||||
grad0 = _mm_andnot_ps(clipped0, grad0);
|
||||
grad1 = _mm_andnot_ps(clipped1, grad1);
|
||||
grad2 = _mm_andnot_ps(clipped2, grad2);
|
||||
grad3 = _mm_andnot_ps(clipped3, grad3);
|
||||
|
||||
_mm_storeu_ps(&gradients_[i + 0], grad0);
|
||||
_mm_storeu_ps(&gradients_[i + 4], grad1);
|
||||
_mm_storeu_ps(&gradients_[i + 8], grad2);
|
||||
_mm_storeu_ps(&gradients_[i + 12], grad3);
|
||||
|
||||
const int clipped_mask =
|
||||
(_mm_movemask_ps(clipped0) << 0)
|
||||
| (_mm_movemask_ps(clipped1) << 4)
|
||||
| (_mm_movemask_ps(clipped2) << 8)
|
||||
| (_mm_movemask_ps(clipped3) << 12);
|
||||
|
||||
num_clipped_ += popcount(clipped_mask);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
for (IndexType b = 0; b < batch_size_; ++b) {
|
||||
const IndexType batch_offset = kOutputDimensions * b;
|
||||
for (IndexType i = 0; i < kOutputDimensions; ++i) {
|
||||
@@ -78,6 +127,9 @@ namespace Eval::NNUE {
|
||||
num_clipped_ += clipped;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
num_total_ += batch_size_ * kOutputDimensions;
|
||||
|
||||
previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate);
|
||||
|
||||
Reference in New Issue
Block a user