diff --git a/src/nnue/trainer/trainer_clipped_relu.h b/src/nnue/trainer/trainer_clipped_relu.h index 124671ed..e4bcecaf 100644 --- a/src/nnue/trainer/trainer_clipped_relu.h +++ b/src/nnue/trainer/trainer_clipped_relu.h @@ -42,16 +42,31 @@ namespace Eval::NNUE { previous_layer_trainer_->initialize(rng); } - // forward propagation - const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector& batch) { - if (output_.size() < kOutputDimensions * batch.size()) { - output_.resize(kOutputDimensions * batch.size()); - gradients_.resize(kInputDimensions * batch.size()); + const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector& combined_batch) + { + if (output_.size() < kOutputDimensions * combined_batch.size()) { + output_.resize(kOutputDimensions * combined_batch.size()); + gradients_.resize(kInputDimensions * combined_batch.size()); } - const auto input = previous_layer_trainer_->propagate(thread_pool, batch); + if (thread_states_.size() < thread_pool.size()) + { + thread_states_.resize(thread_pool.size()); + } - batch_size_ = static_cast(batch.size()); + input_ = previous_layer_trainer_->step_start(thread_pool, combined_batch); + + batch_size_ = static_cast(combined_batch.size()); + + return output_.data(); + } + + // forward propagation + void propagate(Thread& th, const uint64_t offset, const uint64_t count) { + + auto& thread_state = thread_states_[th.thread_idx()]; + + previous_layer_trainer_->propagate(th, offset, count); #if defined (USE_SSE2) @@ -61,16 +76,16 @@ namespace Eval::NNUE { const __m128 kZero4 = _mm_set1_ps(+kZero); const __m128 kOne4 = _mm_set1_ps(+kOne); - for (IndexType b = 0; b < batch.size(); ++b) + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; i += 16) { - __m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]); - __m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]); - __m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]); - __m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]); + __m128 out0 = _mm_loadu_ps(&input_[i + 0 + batch_offset]); + __m128 out1 = _mm_loadu_ps(&input_[i + 4 + batch_offset]); + __m128 out2 = _mm_loadu_ps(&input_[i + 8 + batch_offset]); + __m128 out3 = _mm_loadu_ps(&input_[i + 12 + batch_offset]); out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0)); out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1)); @@ -82,15 +97,15 @@ namespace Eval::NNUE { _mm_storeu_ps(&output_[i + 8 + batch_offset], out2); _mm_storeu_ps(&output_[i + 12 + batch_offset], out3); - __m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]); - __m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]); - __m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]); - __m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]); + __m128 minact0 = _mm_loadu_ps(&thread_state.min_activations_[i + 0]); + __m128 minact1 = _mm_loadu_ps(&thread_state.min_activations_[i + 4]); + __m128 minact2 = _mm_loadu_ps(&thread_state.min_activations_[i + 8]); + __m128 minact3 = _mm_loadu_ps(&thread_state.min_activations_[i + 12]); - __m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]); - __m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]); - __m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]); - __m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]); + __m128 maxact0 = _mm_loadu_ps(&thread_state.max_activations_[i + 0]); + __m128 maxact1 = _mm_loadu_ps(&thread_state.max_activations_[i + 4]); + __m128 maxact2 = _mm_loadu_ps(&thread_state.max_activations_[i + 8]); + __m128 maxact3 = _mm_loadu_ps(&thread_state.max_activations_[i + 12]); minact0 = _mm_min_ps(out0, minact0); minact1 = _mm_min_ps(out1, minact1); @@ -102,40 +117,41 @@ namespace Eval::NNUE { maxact2 = _mm_max_ps(out2, maxact2); maxact3 = _mm_max_ps(out3, maxact3); - _mm_storeu_ps(&min_activations_[i + 0], minact0); - _mm_storeu_ps(&min_activations_[i + 4], minact1); - _mm_storeu_ps(&min_activations_[i + 8], minact2); - _mm_storeu_ps(&min_activations_[i + 12], minact3); + _mm_storeu_ps(&thread_state.min_activations_[i + 0], minact0); + _mm_storeu_ps(&thread_state.min_activations_[i + 4], minact1); + _mm_storeu_ps(&thread_state.min_activations_[i + 8], minact2); + _mm_storeu_ps(&thread_state.min_activations_[i + 12], minact3); - _mm_storeu_ps(&max_activations_[i + 0], maxact0); - _mm_storeu_ps(&max_activations_[i + 4], maxact1); - _mm_storeu_ps(&max_activations_[i + 8], maxact2); - _mm_storeu_ps(&max_activations_[i + 12], maxact3); + _mm_storeu_ps(&thread_state.max_activations_[i + 0], maxact0); + _mm_storeu_ps(&thread_state.max_activations_[i + 4], maxact1); + _mm_storeu_ps(&thread_state.max_activations_[i + 8], maxact2); + _mm_storeu_ps(&thread_state.max_activations_[i + 12], maxact3); } } } #else - for (IndexType b = 0; b < batch_size_; ++b) { + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; ++i) { const IndexType index = batch_offset + i; - output_[index] = std::max(+kZero, std::min(+kOne, input[index])); - min_activations_[i] = std::min(min_activations_[i], output_[index]); - max_activations_[i] = std::max(max_activations_[i], output_[index]); + output_[index] = std::max(+kZero, std::min(+kOne, input_[index])); + thread_state.min_activations_[i] = std::min(thread_state.min_activations_[i], output_[index]); + thread_state.max_activations_[i] = std::max(thread_state.max_activations_[i], output_[index]); } } #endif - - return output_.data(); } // backpropagation - void backpropagate(ThreadPool& thread_pool, + void backpropagate(Thread& th, const LearnFloatType* gradients, - LearnFloatType learning_rate) { + const uint64_t offset, + const uint64_t count) { + + auto& thread_state = thread_states_[th.thread_idx()]; #if defined (USE_SSE2) @@ -145,62 +161,78 @@ namespace Eval::NNUE { const __m128 kZero4 = _mm_set1_ps(+kZero); const __m128 kOne4 = _mm_set1_ps(+kOne); - const IndexType total_size = batch_size_ * kOutputDimensions; - - for (IndexType i = 0; i < total_size; i += 16) + for (IndexType b = offset; b < offset + count; ++b) { - __m128 out0 = _mm_loadu_ps(&output_[i + 0]); - __m128 out1 = _mm_loadu_ps(&output_[i + 4]); - __m128 out2 = _mm_loadu_ps(&output_[i + 8]); - __m128 out3 = _mm_loadu_ps(&output_[i + 12]); + const IndexType batch_offset = kOutputDimensions * b; - __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4)); - __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4)); - __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4)); - __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4)); + for (IndexType i = 0; i < kOutputDimensions; i += 16) + { + __m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]); + __m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]); + __m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]); + __m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]); - __m128 grad0 = _mm_loadu_ps(&gradients[i + 0]); - __m128 grad1 = _mm_loadu_ps(&gradients[i + 4]); - __m128 grad2 = _mm_loadu_ps(&gradients[i + 8]); - __m128 grad3 = _mm_loadu_ps(&gradients[i + 12]); + __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4)); + __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4)); + __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4)); + __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4)); - grad0 = _mm_andnot_ps(clipped0, grad0); - grad1 = _mm_andnot_ps(clipped1, grad1); - grad2 = _mm_andnot_ps(clipped2, grad2); - grad3 = _mm_andnot_ps(clipped3, grad3); + __m128 grad0 = _mm_loadu_ps(&gradients[batch_offset + i + 0]); + __m128 grad1 = _mm_loadu_ps(&gradients[batch_offset + i + 4]); + __m128 grad2 = _mm_loadu_ps(&gradients[batch_offset + i + 8]); + __m128 grad3 = _mm_loadu_ps(&gradients[batch_offset + i + 12]); - _mm_storeu_ps(&gradients_[i + 0], grad0); - _mm_storeu_ps(&gradients_[i + 4], grad1); - _mm_storeu_ps(&gradients_[i + 8], grad2); - _mm_storeu_ps(&gradients_[i + 12], grad3); + grad0 = _mm_andnot_ps(clipped0, grad0); + grad1 = _mm_andnot_ps(clipped1, grad1); + grad2 = _mm_andnot_ps(clipped2, grad2); + grad3 = _mm_andnot_ps(clipped3, grad3); - const int clipped_mask = - (_mm_movemask_ps(clipped0) << 0) - | (_mm_movemask_ps(clipped1) << 4) - | (_mm_movemask_ps(clipped2) << 8) - | (_mm_movemask_ps(clipped3) << 12); + _mm_storeu_ps(&gradients_[batch_offset + i + 0], grad0); + _mm_storeu_ps(&gradients_[batch_offset + i + 4], grad1); + _mm_storeu_ps(&gradients_[batch_offset + i + 8], grad2); + _mm_storeu_ps(&gradients_[batch_offset + i + 12], grad3); - num_clipped_ += popcount(clipped_mask); + const int clipped_mask = + (_mm_movemask_ps(clipped0) << 0) + | (_mm_movemask_ps(clipped1) << 4) + | (_mm_movemask_ps(clipped2) << 8) + | (_mm_movemask_ps(clipped3) << 12); + + thread_state.num_clipped_ += popcount(clipped_mask); + } } } #else - for (IndexType b = 0; b < batch_size_; ++b) { + for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; ++i) { const IndexType index = batch_offset + i; const bool clipped = (output_[index] <= kZero) | (output_[index] >= kOne); gradients_[index] = gradients[index] * !clipped; - num_clipped_ += clipped; + thread_state.num_clipped_ += clipped; } } #endif - num_total_ += batch_size_ * kOutputDimensions; + thread_state.num_total_ += count * kOutputDimensions; - previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate); + previous_layer_trainer_->backpropagate(th, gradients_.data(), offset, count); + } + + void reduce_thread_state() + { + for (IndexType i = 1; i < thread_states_.size(); ++i) + { + thread_states_[0] += thread_states_[i]; + } + } + + void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) + { + previous_layer_trainer_->step_end(thread_pool, learning_rate); } private: @@ -215,22 +247,21 @@ namespace Eval::NNUE { } void reset_stats() { - std::fill(std::begin(min_activations_), std::end(min_activations_), - std::numeric_limits::max()); - std::fill(std::begin(max_activations_), std::end(max_activations_), - std::numeric_limits::lowest()); - - num_clipped_ = 0; - num_total_ = 0; + for(auto& state : thread_states_) + state.reset(); } // Check if there are any problems with learning void check_health() { + reduce_thread_state(); + + auto& main_thread_state = thread_states_[0]; + const auto largest_min_activation = *std::max_element( - std::begin(min_activations_), std::end(min_activations_)); + std::begin(main_thread_state.min_activations_), std::end(main_thread_state.min_activations_)); const auto smallest_max_activation = *std::min_element( - std::begin(max_activations_), std::end(max_activations_)); + std::begin(main_thread_state.max_activations_), std::end(main_thread_state.max_activations_)); auto out = sync_region_cout.new_region(); @@ -243,7 +274,7 @@ namespace Eval::NNUE { << " , smallest max activation = " << smallest_max_activation << std::endl; - out << " - clipped " << static_cast(num_clipped_) / num_total_ * 100.0 << "% of outputs" + out << " - clipped " << static_cast(main_thread_state.num_clipped_) / main_thread_state.num_total_ * 100.0 << "% of outputs" << std::endl; out.unlock(); @@ -262,9 +293,10 @@ namespace Eval::NNUE { // number of samples in mini-batch IndexType batch_size_; - IndexType num_clipped_; IndexType num_total_; + const LearnFloatType* input_; + // Trainer of the previous layer const std::shared_ptr> previous_layer_trainer_; @@ -277,9 +309,44 @@ namespace Eval::NNUE { // buffer for back propagation std::vector> gradients_; - // Health check statistics - LearnFloatType min_activations_[kOutputDimensions]; - LearnFloatType max_activations_[kOutputDimensions]; + struct alignas(kCacheLineSize) ThreadState + { + // Health check statistics + LearnFloatType min_activations_[kOutputDimensions]; + LearnFloatType max_activations_[kOutputDimensions]; + IndexType num_clipped_; + IndexType num_total_; + + ThreadState() { reset(); } + + ThreadState& operator+=(const ThreadState& other) + { + for (IndexType i = 0; i < kOutputDimensions; ++i) + { + min_activations_[i] = std::min(min_activations_[i], other.min_activations_[i]); + } + + for (IndexType i = 0; i < kOutputDimensions; ++i) + { + max_activations_[i] = std::max(max_activations_[i], other.max_activations_[i]); + } + + num_clipped_ += other.num_clipped_; + num_total_ += other.num_total_; + + return *this; + } + + void reset() + { + std::fill(std::begin(min_activations_), std::end(min_activations_), std::numeric_limits::max()); + std::fill(std::begin(max_activations_), std::end(max_activations_), std::numeric_limits::lowest()); + num_clipped_ = 0; + num_total_ = 0; + } + }; + + std::vector> thread_states_; }; } // namespace Eval::NNUE