Prepare clipped relu trainer.

This commit is contained in:
Tomasz Sobczyk
2020-11-22 20:16:49 +01:00
committed by nodchip
parent 774b023641
commit 401fc0fbab

View File

@@ -42,16 +42,31 @@ namespace Eval::NNUE {
previous_layer_trainer_->initialize(rng);
}
// forward propagation
const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector<Example>& batch) {
if (output_.size() < kOutputDimensions * batch.size()) {
output_.resize(kOutputDimensions * batch.size());
gradients_.resize(kInputDimensions * batch.size());
const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector<Example>& combined_batch)
{
if (output_.size() < kOutputDimensions * combined_batch.size()) {
output_.resize(kOutputDimensions * combined_batch.size());
gradients_.resize(kInputDimensions * combined_batch.size());
}
const auto input = previous_layer_trainer_->propagate(thread_pool, batch);
if (thread_states_.size() < thread_pool.size())
{
thread_states_.resize(thread_pool.size());
}
batch_size_ = static_cast<IndexType>(batch.size());
input_ = previous_layer_trainer_->step_start(thread_pool, combined_batch);
batch_size_ = static_cast<IndexType>(combined_batch.size());
return output_.data();
}
// forward propagation
void propagate(Thread& th, const uint64_t offset, const uint64_t count) {
auto& thread_state = thread_states_[th.thread_idx()];
previous_layer_trainer_->propagate(th, offset, count);
#if defined (USE_SSE2)
@@ -61,16 +76,16 @@ namespace Eval::NNUE {
const __m128 kZero4 = _mm_set1_ps(+kZero);
const __m128 kOne4 = _mm_set1_ps(+kOne);
for (IndexType b = 0; b < batch.size(); ++b)
for (IndexType b = offset; b < offset + count; ++b)
{
const IndexType batch_offset = kOutputDimensions * b;
for (IndexType i = 0; i < kOutputDimensions; i += 16)
{
__m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]);
__m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]);
__m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]);
__m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]);
__m128 out0 = _mm_loadu_ps(&input_[i + 0 + batch_offset]);
__m128 out1 = _mm_loadu_ps(&input_[i + 4 + batch_offset]);
__m128 out2 = _mm_loadu_ps(&input_[i + 8 + batch_offset]);
__m128 out3 = _mm_loadu_ps(&input_[i + 12 + batch_offset]);
out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
@@ -82,15 +97,15 @@ namespace Eval::NNUE {
_mm_storeu_ps(&output_[i + 8 + batch_offset], out2);
_mm_storeu_ps(&output_[i + 12 + batch_offset], out3);
__m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
__m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
__m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
__m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
__m128 minact0 = _mm_loadu_ps(&thread_state.min_activations_[i + 0]);
__m128 minact1 = _mm_loadu_ps(&thread_state.min_activations_[i + 4]);
__m128 minact2 = _mm_loadu_ps(&thread_state.min_activations_[i + 8]);
__m128 minact3 = _mm_loadu_ps(&thread_state.min_activations_[i + 12]);
__m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
__m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
__m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
__m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
__m128 maxact0 = _mm_loadu_ps(&thread_state.max_activations_[i + 0]);
__m128 maxact1 = _mm_loadu_ps(&thread_state.max_activations_[i + 4]);
__m128 maxact2 = _mm_loadu_ps(&thread_state.max_activations_[i + 8]);
__m128 maxact3 = _mm_loadu_ps(&thread_state.max_activations_[i + 12]);
minact0 = _mm_min_ps(out0, minact0);
minact1 = _mm_min_ps(out1, minact1);
@@ -102,40 +117,41 @@ namespace Eval::NNUE {
maxact2 = _mm_max_ps(out2, maxact2);
maxact3 = _mm_max_ps(out3, maxact3);
_mm_storeu_ps(&min_activations_[i + 0], minact0);
_mm_storeu_ps(&min_activations_[i + 4], minact1);
_mm_storeu_ps(&min_activations_[i + 8], minact2);
_mm_storeu_ps(&min_activations_[i + 12], minact3);
_mm_storeu_ps(&thread_state.min_activations_[i + 0], minact0);
_mm_storeu_ps(&thread_state.min_activations_[i + 4], minact1);
_mm_storeu_ps(&thread_state.min_activations_[i + 8], minact2);
_mm_storeu_ps(&thread_state.min_activations_[i + 12], minact3);
_mm_storeu_ps(&max_activations_[i + 0], maxact0);
_mm_storeu_ps(&max_activations_[i + 4], maxact1);
_mm_storeu_ps(&max_activations_[i + 8], maxact2);
_mm_storeu_ps(&max_activations_[i + 12], maxact3);
_mm_storeu_ps(&thread_state.max_activations_[i + 0], maxact0);
_mm_storeu_ps(&thread_state.max_activations_[i + 4], maxact1);
_mm_storeu_ps(&thread_state.max_activations_[i + 8], maxact2);
_mm_storeu_ps(&thread_state.max_activations_[i + 12], maxact3);
}
}
}
#else
for (IndexType b = 0; b < batch_size_; ++b) {
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
for (IndexType i = 0; i < kOutputDimensions; ++i) {
const IndexType index = batch_offset + i;
output_[index] = std::max(+kZero, std::min(+kOne, input[index]));
min_activations_[i] = std::min(min_activations_[i], output_[index]);
max_activations_[i] = std::max(max_activations_[i], output_[index]);
output_[index] = std::max(+kZero, std::min(+kOne, input_[index]));
thread_state.min_activations_[i] = std::min(thread_state.min_activations_[i], output_[index]);
thread_state.max_activations_[i] = std::max(thread_state.max_activations_[i], output_[index]);
}
}
#endif
return output_.data();
}
// backpropagation
void backpropagate(ThreadPool& thread_pool,
void backpropagate(Thread& th,
const LearnFloatType* gradients,
LearnFloatType learning_rate) {
const uint64_t offset,
const uint64_t count) {
auto& thread_state = thread_states_[th.thread_idx()];
#if defined (USE_SSE2)
@@ -145,62 +161,78 @@ namespace Eval::NNUE {
const __m128 kZero4 = _mm_set1_ps(+kZero);
const __m128 kOne4 = _mm_set1_ps(+kOne);
const IndexType total_size = batch_size_ * kOutputDimensions;
for (IndexType i = 0; i < total_size; i += 16)
for (IndexType b = offset; b < offset + count; ++b)
{
__m128 out0 = _mm_loadu_ps(&output_[i + 0]);
__m128 out1 = _mm_loadu_ps(&output_[i + 4]);
__m128 out2 = _mm_loadu_ps(&output_[i + 8]);
__m128 out3 = _mm_loadu_ps(&output_[i + 12]);
const IndexType batch_offset = kOutputDimensions * b;
__m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
__m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
__m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
__m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
for (IndexType i = 0; i < kOutputDimensions; i += 16)
{
__m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]);
__m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]);
__m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]);
__m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]);
__m128 grad0 = _mm_loadu_ps(&gradients[i + 0]);
__m128 grad1 = _mm_loadu_ps(&gradients[i + 4]);
__m128 grad2 = _mm_loadu_ps(&gradients[i + 8]);
__m128 grad3 = _mm_loadu_ps(&gradients[i + 12]);
__m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4));
__m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4));
__m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4));
__m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4));
grad0 = _mm_andnot_ps(clipped0, grad0);
grad1 = _mm_andnot_ps(clipped1, grad1);
grad2 = _mm_andnot_ps(clipped2, grad2);
grad3 = _mm_andnot_ps(clipped3, grad3);
__m128 grad0 = _mm_loadu_ps(&gradients[batch_offset + i + 0]);
__m128 grad1 = _mm_loadu_ps(&gradients[batch_offset + i + 4]);
__m128 grad2 = _mm_loadu_ps(&gradients[batch_offset + i + 8]);
__m128 grad3 = _mm_loadu_ps(&gradients[batch_offset + i + 12]);
_mm_storeu_ps(&gradients_[i + 0], grad0);
_mm_storeu_ps(&gradients_[i + 4], grad1);
_mm_storeu_ps(&gradients_[i + 8], grad2);
_mm_storeu_ps(&gradients_[i + 12], grad3);
grad0 = _mm_andnot_ps(clipped0, grad0);
grad1 = _mm_andnot_ps(clipped1, grad1);
grad2 = _mm_andnot_ps(clipped2, grad2);
grad3 = _mm_andnot_ps(clipped3, grad3);
const int clipped_mask =
(_mm_movemask_ps(clipped0) << 0)
| (_mm_movemask_ps(clipped1) << 4)
| (_mm_movemask_ps(clipped2) << 8)
| (_mm_movemask_ps(clipped3) << 12);
_mm_storeu_ps(&gradients_[batch_offset + i + 0], grad0);
_mm_storeu_ps(&gradients_[batch_offset + i + 4], grad1);
_mm_storeu_ps(&gradients_[batch_offset + i + 8], grad2);
_mm_storeu_ps(&gradients_[batch_offset + i + 12], grad3);
num_clipped_ += popcount(clipped_mask);
const int clipped_mask =
(_mm_movemask_ps(clipped0) << 0)
| (_mm_movemask_ps(clipped1) << 4)
| (_mm_movemask_ps(clipped2) << 8)
| (_mm_movemask_ps(clipped3) << 12);
thread_state.num_clipped_ += popcount(clipped_mask);
}
}
}
#else
for (IndexType b = 0; b < batch_size_; ++b) {
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b;
for (IndexType i = 0; i < kOutputDimensions; ++i) {
const IndexType index = batch_offset + i;
const bool clipped = (output_[index] <= kZero) | (output_[index] >= kOne);
gradients_[index] = gradients[index] * !clipped;
num_clipped_ += clipped;
thread_state.num_clipped_ += clipped;
}
}
#endif
num_total_ += batch_size_ * kOutputDimensions;
thread_state.num_total_ += count * kOutputDimensions;
previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate);
previous_layer_trainer_->backpropagate(th, gradients_.data(), offset, count);
}
void reduce_thread_state()
{
for (IndexType i = 1; i < thread_states_.size(); ++i)
{
thread_states_[0] += thread_states_[i];
}
}
void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate)
{
previous_layer_trainer_->step_end(thread_pool, learning_rate);
}
private:
@@ -215,22 +247,21 @@ namespace Eval::NNUE {
}
void reset_stats() {
std::fill(std::begin(min_activations_), std::end(min_activations_),
std::numeric_limits<LearnFloatType>::max());
std::fill(std::begin(max_activations_), std::end(max_activations_),
std::numeric_limits<LearnFloatType>::lowest());
num_clipped_ = 0;
num_total_ = 0;
for(auto& state : thread_states_)
state.reset();
}
// Check if there are any problems with learning
void check_health() {
reduce_thread_state();
auto& main_thread_state = thread_states_[0];
const auto largest_min_activation = *std::max_element(
std::begin(min_activations_), std::end(min_activations_));
std::begin(main_thread_state.min_activations_), std::end(main_thread_state.min_activations_));
const auto smallest_max_activation = *std::min_element(
std::begin(max_activations_), std::end(max_activations_));
std::begin(main_thread_state.max_activations_), std::end(main_thread_state.max_activations_));
auto out = sync_region_cout.new_region();
@@ -243,7 +274,7 @@ namespace Eval::NNUE {
<< " , smallest max activation = " << smallest_max_activation
<< std::endl;
out << " - clipped " << static_cast<double>(num_clipped_) / num_total_ * 100.0 << "% of outputs"
out << " - clipped " << static_cast<double>(main_thread_state.num_clipped_) / main_thread_state.num_total_ * 100.0 << "% of outputs"
<< std::endl;
out.unlock();
@@ -262,9 +293,10 @@ namespace Eval::NNUE {
// number of samples in mini-batch
IndexType batch_size_;
IndexType num_clipped_;
IndexType num_total_;
const LearnFloatType* input_;
// Trainer of the previous layer
const std::shared_ptr<Trainer<PreviousLayer>> previous_layer_trainer_;
@@ -277,9 +309,44 @@ namespace Eval::NNUE {
// buffer for back propagation
std::vector<LearnFloatType, CacheLineAlignedAllocator<LearnFloatType>> gradients_;
// Health check statistics
LearnFloatType min_activations_[kOutputDimensions];
LearnFloatType max_activations_[kOutputDimensions];
struct alignas(kCacheLineSize) ThreadState
{
// Health check statistics
LearnFloatType min_activations_[kOutputDimensions];
LearnFloatType max_activations_[kOutputDimensions];
IndexType num_clipped_;
IndexType num_total_;
ThreadState() { reset(); }
ThreadState& operator+=(const ThreadState& other)
{
for (IndexType i = 0; i < kOutputDimensions; ++i)
{
min_activations_[i] = std::min(min_activations_[i], other.min_activations_[i]);
}
for (IndexType i = 0; i < kOutputDimensions; ++i)
{
max_activations_[i] = std::max(max_activations_[i], other.max_activations_[i]);
}
num_clipped_ += other.num_clipped_;
num_total_ += other.num_total_;
return *this;
}
void reset()
{
std::fill(std::begin(min_activations_), std::end(min_activations_), std::numeric_limits<float>::max());
std::fill(std::begin(max_activations_), std::end(max_activations_), std::numeric_limits<float>::lowest());
num_clipped_ = 0;
num_total_ = 0;
}
};
std::vector<ThreadState, CacheLineAlignedAllocator<ThreadState>> thread_states_;
};
} // namespace Eval::NNUE