From 4ea8572b6d8fdbd092c94954c78a6b0a47289083 Mon Sep 17 00:00:00 2001 From: Tomasz Sobczyk Date: Sun, 22 Nov 2020 18:27:51 +0100 Subject: [PATCH] Add single threaded sgemm. --- src/extra/stockfish_blas.cpp | 290 +++++++++++++++++++++++++++++++++++ src/extra/stockfish_blas.h | 10 ++ 2 files changed, 300 insertions(+) diff --git a/src/extra/stockfish_blas.cpp b/src/extra/stockfish_blas.cpp index 0ba40b49..109a4b44 100644 --- a/src/extra/stockfish_blas.cpp +++ b/src/extra/stockfish_blas.cpp @@ -546,6 +546,156 @@ namespace Blas { ); thread_pool.wait_for_workers_finished(); +#endif + } + + void sgemm_row_major_transpose_right( + const int M, const int N, const int K, + const float alpha, + const float * SF_BLAS_RESTRICT A, const int lda, + const float * SF_BLAS_RESTRICT B, const int ldb, + const float beta, + float * SF_BLAS_RESTRICT C, const int ldc + ) + { + +#if defined(USE_SSE3) + + const __m128 alpha4 = _mm_set1_ps(alpha); + const __m128 beta4 = _mm_set1_ps(beta); + + for (int m = 0; m < M - 1; m += 2) + { + int n = 0; + for (; n < N - 3; n += 4) + { + // mn + __m128 sum00 = _mm_setzero_ps(); + __m128 sum01 = _mm_setzero_ps(); + __m128 sum02 = _mm_setzero_ps(); + __m128 sum03 = _mm_setzero_ps(); + __m128 sum10 = _mm_setzero_ps(); + __m128 sum11 = _mm_setzero_ps(); + __m128 sum12 = _mm_setzero_ps(); + __m128 sum13 = _mm_setzero_ps(); + + // Horizontal sum of elements in sum[m][n] corresponds to + // the final element in the C. + + int k = 0; + for (; k < K - 3; k += 4) + { + const __m128 a0 = _mm_loadu_ps(&A[(m+0)*lda+k+0]); + const __m128 a1 = _mm_loadu_ps(&A[(m+1)*lda+k+0]); + + const __m128 b0 = _mm_loadu_ps(&B[(n+0)*ldb+k+0]); + const __m128 b1 = _mm_loadu_ps(&B[(n+1)*ldb+k+0]); + const __m128 b2 = _mm_loadu_ps(&B[(n+2)*ldb+k+0]); + const __m128 b3 = _mm_loadu_ps(&B[(n+3)*ldb+k+0]); + + sum00 = _mm_add_ps(sum00, _mm_mul_ps(a0, b0)); + sum01 = _mm_add_ps(sum01, _mm_mul_ps(a0, b1)); + sum02 = _mm_add_ps(sum02, _mm_mul_ps(a0, b2)); + sum03 = _mm_add_ps(sum03, _mm_mul_ps(a0, b3)); + sum10 = _mm_add_ps(sum10, _mm_mul_ps(a1, b0)); + sum11 = _mm_add_ps(sum11, _mm_mul_ps(a1, b1)); + sum12 = _mm_add_ps(sum12, _mm_mul_ps(a1, b2)); + sum13 = _mm_add_ps(sum13, _mm_mul_ps(a1, b3)); + } + + for(; k < K; k += 1) + { + const float a0 = A[(m+0)*lda+k+0]; + const float a1 = A[(m+1)*lda+k+0]; + + const float b0 = B[(n+0)*ldb+k+0]; + const float b1 = B[(n+1)*ldb+k+0]; + const float b2 = B[(n+2)*ldb+k+0]; + const float b3 = B[(n+3)*ldb+k+0]; + + // Since all will be summed vertically anyway we can + // just add to the first element. + // Other elements are left unmodified. + sum00 = _mm_add_ss(sum00, _mm_set_ss(a0 * b0)); + sum01 = _mm_add_ss(sum01, _mm_set_ss(a0 * b1)); + sum02 = _mm_add_ss(sum02, _mm_set_ss(a0 * b2)); + sum03 = _mm_add_ss(sum03, _mm_set_ss(a0 * b3)); + sum10 = _mm_add_ss(sum10, _mm_set_ss(a1 * b0)); + sum11 = _mm_add_ss(sum11, _mm_set_ss(a1 * b1)); + sum12 = _mm_add_ss(sum12, _mm_set_ss(a1 * b2)); + sum13 = _mm_add_ss(sum13, _mm_set_ss(a1 * b3)); + } + + __m128 s0 = m128_hadd_ps(sum00, sum01, sum02, sum03); + __m128 s1 = m128_hadd_ps(sum10, sum11, sum12, sum13); + s0 = _mm_mul_ps(s0, alpha4); + s1 = _mm_mul_ps(s1, alpha4); + + __m128 c0 = _mm_loadu_ps(&C[(m+0)*ldc+(n+0)]); + __m128 c1 = _mm_loadu_ps(&C[(m+1)*ldc+(n+0)]); + c0 = _mm_mul_ps(c0, beta4); + c1 = _mm_mul_ps(c1, beta4); + + c0 = _mm_add_ps(c0, s0); + c1 = _mm_add_ps(c1, s1); + + _mm_storeu_ps(&C[(m+0)*ldc+(n+0)], c0); + _mm_storeu_ps(&C[(m+1)*ldc+(n+0)], c1); + } + + for(; n < N; n += 1) + { + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int k = 0; k < K; ++k) + { + const float a0 = A[(m+0)*lda+k+0]; + const float a1 = A[(m+1)*lda+k+0]; + + const float b0 = B[(n+0)*ldb+k+0]; + + sum0 += a0 * b0; + sum1 += a1 * b0; + } + + C[(m+0)*ldc+(n+0)] = C[(m+0)*ldc+(n+0)] * beta + sum0 * alpha; + C[(m+1)*ldc+(n+0)] = C[(m+1)*ldc+(n+0)] * beta + sum1 * alpha; + } + } + + for (; m < M; m += 1) + { + for (int n = 0; n < N; n += 1) + { + float sum = 0.0f; + + for (int k = 0; k < K; k += 1) + { + sum += A[m*lda + k] * B[n*ldb + k]; + } + + C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha; + } + } + +#else + + for (int m = 0; m < M; m += 1) + { + for (int n = 0; n < N; n += 1) + { + float sum = 0.0f; + + for (int k = 0; k < K; k += 1) + { + sum += A[m*lda + k] * B[n*ldb + k]; + } + + C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha; + } + } + #endif } @@ -605,6 +755,35 @@ namespace Blas { ); } + void sgemm_row_major_transpose_none( + const int M, const int N, const int K, + const float alpha, + const float * SF_BLAS_RESTRICT A, const int lda, + const float * SF_BLAS_RESTRICT B, const int ldb, + const float beta, + float * SF_BLAS_RESTRICT C, const int ldc + ) + { + constexpr static int temporary_buffer_index = 1; + + auto B_tr = get_thread_local_temporary_storage(K * N, temporary_buffer_index); + + transpose( + K, N, + B, ldb, + B_tr, K + ); + + sgemm_row_major_transpose_right( + M, N, K, + alpha, + A, lda, + B_tr, K, + beta, + C, ldc + ); + } + void sgemm_row_major( ThreadPool& thread_pool, MatrixTranspose TransA, MatrixTranspose TransB, @@ -684,6 +863,80 @@ namespace Blas { } } + void sgemm_row_major( + MatrixTranspose TransA, MatrixTranspose TransB, + const int M, const int N, const int K, + const float alpha, + const float * SF_BLAS_RESTRICT A, const int lda, + const float * SF_BLAS_RESTRICT B, const int ldb, + const float beta, + float * SF_BLAS_RESTRICT C, const int ldc + ) + { + constexpr static int temporary_buffer_index = 0; + + if (TransA == MatrixTranspose::Trans && TransB == MatrixTranspose::Trans) + { + auto A_tr = get_thread_local_temporary_storage(K * M, temporary_buffer_index); + + transpose( + K, M, + A, lda, + A_tr, K + ); + + sgemm_row_major_transpose_right( + M, N, K, + alpha, + A_tr, K, + B, ldb, + beta, + C, ldc + ); + } + else if (TransA == MatrixTranspose::NoTrans && TransB == MatrixTranspose::Trans) + { + sgemm_row_major_transpose_right( + M, N, K, + alpha, + A, lda, + B, ldb, + beta, + C, ldc + ); + } + else if (TransA == MatrixTranspose::Trans && TransB == MatrixTranspose::NoTrans) + { + auto A_tr = get_thread_local_temporary_storage(K * M, temporary_buffer_index); + + transpose( + K, M, + A, lda, + A_tr, K + ); + + sgemm_row_major_transpose_none( + M, N, K, + alpha, + A_tr, K, + B, ldb, + beta, + C, ldc + ); + } + else // no transpositions + { + sgemm_row_major_transpose_none( + M, N, K, + alpha, + A, lda, + B, ldb, + beta, + C, ldc + ); + } + } + void sgemm( ThreadPool& thread_pool, MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB, @@ -723,6 +976,43 @@ namespace Blas { } } + + void sgemm( + MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB, + const int M, const int N, const int K, + const float alpha, + const float * SF_BLAS_RESTRICT A, const int lda, + const float * SF_BLAS_RESTRICT B, const int ldb, + const float beta, + float * SF_BLAS_RESTRICT C, const int ldc + ) + { + if (layout == MatrixLayout::RowMajor) + { + sgemm_row_major( + TransA, TransB, + M, N, K, + alpha, + A, lda, + B, ldb, + beta, + C, ldc + ); + } + else + { + sgemm_row_major( + TransB, TransA, + N, M, K, + alpha, + B, ldb, + A, lda, + beta, + C, ldc + ); + } + } + std::vector generate_random_matrix(int rows, int cols) { std::vector m(rows * cols); diff --git a/src/extra/stockfish_blas.h b/src/extra/stockfish_blas.h index 65da7e99..f551bbf2 100644 --- a/src/extra/stockfish_blas.h +++ b/src/extra/stockfish_blas.h @@ -118,6 +118,16 @@ namespace Blas { float * SF_BLAS_RESTRICT C, const int ldc ); + void sgemm( + MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB, + const int M, const int N, const int K, + const float alpha, + const float * SF_BLAS_RESTRICT A, const int lda, + const float * SF_BLAS_RESTRICT B, const int ldb, + const float beta, + float * SF_BLAS_RESTRICT C, const int ldc + ); + void test( ThreadPool& thread_pool );