Optimize trainer clipped relu propagate

This commit is contained in:
Tomasz Sobczyk
2020-10-27 18:57:47 +01:00
committed by nodchip
parent b5714c4084
commit db1b33d4ac

View File

@@ -50,7 +50,73 @@ namespace Eval::NNUE {
}
const auto input = previous_layer_trainer_->propagate(thread_pool, batch);
batch_size_ = static_cast<IndexType>(batch.size());
#if defined (USE_SSE2)
{
static_assert(kOutputDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
const __m128 kZero4 = _mm_set1_ps(+kZero);
const __m128 kOne4 = _mm_set1_ps(+kOne);
for (IndexType b = 0; b < batch.size(); ++b)
{
const IndexType batch_offset = kOutputDimensions * b;
for (IndexType i = 0; i < kOutputDimensions; i += 16)
{
__m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]);
__m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]);
__m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]);
__m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]);
out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
_mm_storeu_ps(&output_[i + 0 + batch_offset], out0);
_mm_storeu_ps(&output_[i + 4 + batch_offset], out1);
_mm_storeu_ps(&output_[i + 8 + batch_offset], out2);
_mm_storeu_ps(&output_[i + 12 + batch_offset], out3);
__m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
__m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
__m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
__m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
__m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
__m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
__m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
__m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
minact0 = _mm_min_ps(out0, minact0);
minact1 = _mm_min_ps(out1, minact1);
minact2 = _mm_min_ps(out2, minact2);
minact3 = _mm_min_ps(out3, minact3);
maxact0 = _mm_max_ps(out0, maxact0);
maxact1 = _mm_max_ps(out1, maxact1);
maxact2 = _mm_max_ps(out2, maxact2);
maxact3 = _mm_max_ps(out3, maxact3);
_mm_storeu_ps(&min_activations_[i + 0], minact0);
_mm_storeu_ps(&min_activations_[i + 4], minact1);
_mm_storeu_ps(&min_activations_[i + 8], minact2);
_mm_storeu_ps(&min_activations_[i + 12], minact3);
_mm_storeu_ps(&max_activations_[i + 0], maxact0);
_mm_storeu_ps(&max_activations_[i + 4], maxact1);
_mm_storeu_ps(&max_activations_[i + 8], maxact2);
_mm_storeu_ps(&max_activations_[i + 12], maxact3);
}
}
}
#else
for (IndexType b = 0; b < batch_size_; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
for (IndexType i = 0; i < kOutputDimensions; ++i) {
@@ -61,6 +127,8 @@ namespace Eval::NNUE {
}
}
#endif
return output_.data();
}