mirror of
https://github.com/HChaZZY/Stockfish.git
synced 2025-12-24 19:16:49 +08:00
Add single threaded sgemm.
This commit is contained in:
@@ -546,6 +546,156 @@ namespace Blas {
|
||||
);
|
||||
thread_pool.wait_for_workers_finished();
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
void sgemm_row_major_transpose_right(
|
||||
const int M, const int N, const int K,
|
||||
const float alpha,
|
||||
const float * SF_BLAS_RESTRICT A, const int lda,
|
||||
const float * SF_BLAS_RESTRICT B, const int ldb,
|
||||
const float beta,
|
||||
float * SF_BLAS_RESTRICT C, const int ldc
|
||||
)
|
||||
{
|
||||
|
||||
#if defined(USE_SSE3)
|
||||
|
||||
const __m128 alpha4 = _mm_set1_ps(alpha);
|
||||
const __m128 beta4 = _mm_set1_ps(beta);
|
||||
|
||||
for (int m = 0; m < M - 1; m += 2)
|
||||
{
|
||||
int n = 0;
|
||||
for (; n < N - 3; n += 4)
|
||||
{
|
||||
// mn
|
||||
__m128 sum00 = _mm_setzero_ps();
|
||||
__m128 sum01 = _mm_setzero_ps();
|
||||
__m128 sum02 = _mm_setzero_ps();
|
||||
__m128 sum03 = _mm_setzero_ps();
|
||||
__m128 sum10 = _mm_setzero_ps();
|
||||
__m128 sum11 = _mm_setzero_ps();
|
||||
__m128 sum12 = _mm_setzero_ps();
|
||||
__m128 sum13 = _mm_setzero_ps();
|
||||
|
||||
// Horizontal sum of elements in sum[m][n] corresponds to
|
||||
// the final element in the C.
|
||||
|
||||
int k = 0;
|
||||
for (; k < K - 3; k += 4)
|
||||
{
|
||||
const __m128 a0 = _mm_loadu_ps(&A[(m+0)*lda+k+0]);
|
||||
const __m128 a1 = _mm_loadu_ps(&A[(m+1)*lda+k+0]);
|
||||
|
||||
const __m128 b0 = _mm_loadu_ps(&B[(n+0)*ldb+k+0]);
|
||||
const __m128 b1 = _mm_loadu_ps(&B[(n+1)*ldb+k+0]);
|
||||
const __m128 b2 = _mm_loadu_ps(&B[(n+2)*ldb+k+0]);
|
||||
const __m128 b3 = _mm_loadu_ps(&B[(n+3)*ldb+k+0]);
|
||||
|
||||
sum00 = _mm_add_ps(sum00, _mm_mul_ps(a0, b0));
|
||||
sum01 = _mm_add_ps(sum01, _mm_mul_ps(a0, b1));
|
||||
sum02 = _mm_add_ps(sum02, _mm_mul_ps(a0, b2));
|
||||
sum03 = _mm_add_ps(sum03, _mm_mul_ps(a0, b3));
|
||||
sum10 = _mm_add_ps(sum10, _mm_mul_ps(a1, b0));
|
||||
sum11 = _mm_add_ps(sum11, _mm_mul_ps(a1, b1));
|
||||
sum12 = _mm_add_ps(sum12, _mm_mul_ps(a1, b2));
|
||||
sum13 = _mm_add_ps(sum13, _mm_mul_ps(a1, b3));
|
||||
}
|
||||
|
||||
for(; k < K; k += 1)
|
||||
{
|
||||
const float a0 = A[(m+0)*lda+k+0];
|
||||
const float a1 = A[(m+1)*lda+k+0];
|
||||
|
||||
const float b0 = B[(n+0)*ldb+k+0];
|
||||
const float b1 = B[(n+1)*ldb+k+0];
|
||||
const float b2 = B[(n+2)*ldb+k+0];
|
||||
const float b3 = B[(n+3)*ldb+k+0];
|
||||
|
||||
// Since all will be summed vertically anyway we can
|
||||
// just add to the first element.
|
||||
// Other elements are left unmodified.
|
||||
sum00 = _mm_add_ss(sum00, _mm_set_ss(a0 * b0));
|
||||
sum01 = _mm_add_ss(sum01, _mm_set_ss(a0 * b1));
|
||||
sum02 = _mm_add_ss(sum02, _mm_set_ss(a0 * b2));
|
||||
sum03 = _mm_add_ss(sum03, _mm_set_ss(a0 * b3));
|
||||
sum10 = _mm_add_ss(sum10, _mm_set_ss(a1 * b0));
|
||||
sum11 = _mm_add_ss(sum11, _mm_set_ss(a1 * b1));
|
||||
sum12 = _mm_add_ss(sum12, _mm_set_ss(a1 * b2));
|
||||
sum13 = _mm_add_ss(sum13, _mm_set_ss(a1 * b3));
|
||||
}
|
||||
|
||||
__m128 s0 = m128_hadd_ps(sum00, sum01, sum02, sum03);
|
||||
__m128 s1 = m128_hadd_ps(sum10, sum11, sum12, sum13);
|
||||
s0 = _mm_mul_ps(s0, alpha4);
|
||||
s1 = _mm_mul_ps(s1, alpha4);
|
||||
|
||||
__m128 c0 = _mm_loadu_ps(&C[(m+0)*ldc+(n+0)]);
|
||||
__m128 c1 = _mm_loadu_ps(&C[(m+1)*ldc+(n+0)]);
|
||||
c0 = _mm_mul_ps(c0, beta4);
|
||||
c1 = _mm_mul_ps(c1, beta4);
|
||||
|
||||
c0 = _mm_add_ps(c0, s0);
|
||||
c1 = _mm_add_ps(c1, s1);
|
||||
|
||||
_mm_storeu_ps(&C[(m+0)*ldc+(n+0)], c0);
|
||||
_mm_storeu_ps(&C[(m+1)*ldc+(n+0)], c1);
|
||||
}
|
||||
|
||||
for(; n < N; n += 1)
|
||||
{
|
||||
float sum0 = 0.0f;
|
||||
float sum1 = 0.0f;
|
||||
|
||||
for (int k = 0; k < K; ++k)
|
||||
{
|
||||
const float a0 = A[(m+0)*lda+k+0];
|
||||
const float a1 = A[(m+1)*lda+k+0];
|
||||
|
||||
const float b0 = B[(n+0)*ldb+k+0];
|
||||
|
||||
sum0 += a0 * b0;
|
||||
sum1 += a1 * b0;
|
||||
}
|
||||
|
||||
C[(m+0)*ldc+(n+0)] = C[(m+0)*ldc+(n+0)] * beta + sum0 * alpha;
|
||||
C[(m+1)*ldc+(n+0)] = C[(m+1)*ldc+(n+0)] * beta + sum1 * alpha;
|
||||
}
|
||||
}
|
||||
|
||||
for (; m < M; m += 1)
|
||||
{
|
||||
for (int n = 0; n < N; n += 1)
|
||||
{
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int k = 0; k < K; k += 1)
|
||||
{
|
||||
sum += A[m*lda + k] * B[n*ldb + k];
|
||||
}
|
||||
|
||||
C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha;
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
for (int m = 0; m < M; m += 1)
|
||||
{
|
||||
for (int n = 0; n < N; n += 1)
|
||||
{
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int k = 0; k < K; k += 1)
|
||||
{
|
||||
sum += A[m*lda + k] * B[n*ldb + k];
|
||||
}
|
||||
|
||||
C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -605,6 +755,35 @@ namespace Blas {
|
||||
);
|
||||
}
|
||||
|
||||
void sgemm_row_major_transpose_none(
|
||||
const int M, const int N, const int K,
|
||||
const float alpha,
|
||||
const float * SF_BLAS_RESTRICT A, const int lda,
|
||||
const float * SF_BLAS_RESTRICT B, const int ldb,
|
||||
const float beta,
|
||||
float * SF_BLAS_RESTRICT C, const int ldc
|
||||
)
|
||||
{
|
||||
constexpr static int temporary_buffer_index = 1;
|
||||
|
||||
auto B_tr = get_thread_local_temporary_storage(K * N, temporary_buffer_index);
|
||||
|
||||
transpose(
|
||||
K, N,
|
||||
B, ldb,
|
||||
B_tr, K
|
||||
);
|
||||
|
||||
sgemm_row_major_transpose_right(
|
||||
M, N, K,
|
||||
alpha,
|
||||
A, lda,
|
||||
B_tr, K,
|
||||
beta,
|
||||
C, ldc
|
||||
);
|
||||
}
|
||||
|
||||
void sgemm_row_major(
|
||||
ThreadPool& thread_pool,
|
||||
MatrixTranspose TransA, MatrixTranspose TransB,
|
||||
@@ -684,6 +863,80 @@ namespace Blas {
|
||||
}
|
||||
}
|
||||
|
||||
// Single-threaded row-major sgemm dispatcher (overload without a
// ThreadPool).
//
// Chooses how to feed the operands to the kernels based on the
// transposition flags:
//
//   TransA == Trans   : A is first copied, transposed, into thread-local
//                       scratch storage (K * M floats, stride K) and that
//                       copy is used as the left operand.
//   TransB == Trans   : the right operand is already in the form expected by
//                       sgemm_row_major_transpose_right, which is called
//                       directly.
//   TransB == NoTrans : sgemm_row_major_transpose_none is called; it
//                       transposes B itself.
//
// Any flag combination not matched explicitly is handled like
// (NoTrans, NoTrans), i.e. A and B are passed through unchanged to
// sgemm_row_major_transpose_none.
void sgemm_row_major(
    MatrixTranspose TransA, MatrixTranspose TransB,
    const int M, const int N, const int K,
    const float alpha,
    const float * SF_BLAS_RESTRICT A, const int lda,
    const float * SF_BLAS_RESTRICT B, const int ldb,
    const float beta,
    float * SF_BLAS_RESTRICT C, const int ldc
)
{
    // Scratch slot for the transposed copy of A; slot 1 is used by
    // sgemm_row_major_transpose_none for its copy of B.
    static constexpr int scratch_slot = 0;

    // Produces the transposed copy of A (row stride K) in scratch storage.
    const auto make_transposed_lhs = [&]()
    {
        auto scratch = get_thread_local_temporary_storage(K * M, scratch_slot);

        transpose(
            K, M,
            A, lda,
            scratch, K
        );

        return scratch;
    };

    const bool lhs_transposed = (TransA == MatrixTranspose::Trans);
    const bool rhs_transposed = (TransB == MatrixTranspose::Trans);

    if (lhs_transposed && rhs_transposed)
    {
        auto lhs = make_transposed_lhs();

        sgemm_row_major_transpose_right(
            M, N, K,
            alpha,
            lhs, K,
            B, ldb,
            beta,
            C, ldc
        );
        return;
    }

    if (TransA == MatrixTranspose::NoTrans && rhs_transposed)
    {
        sgemm_row_major_transpose_right(
            M, N, K,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc
        );
        return;
    }

    if (lhs_transposed && TransB == MatrixTranspose::NoTrans)
    {
        auto lhs = make_transposed_lhs();

        sgemm_row_major_transpose_none(
            M, N, K,
            alpha,
            lhs, K,
            B, ldb,
            beta,
            C, ldc
        );
        return;
    }

    // No transpositions (and any combination not handled above).
    sgemm_row_major_transpose_none(
        M, N, K,
        alpha,
        A, lda,
        B, ldb,
        beta,
        C, ldc
    );
}
|
||||
|
||||
void sgemm(
|
||||
ThreadPool& thread_pool,
|
||||
MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
|
||||
@@ -723,6 +976,43 @@ namespace Blas {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Single-threaded sgemm entry point (overload without a ThreadPool).
//
// Row-major input is forwarded unchanged to sgemm_row_major. For any other
// layout the call is rewritten in terms of the row-major routine by swapping
// the two operands: B (with ldb and TransB) takes the place of A, A (with
// lda and TransA) takes the place of B, and M and N are exchanged. C, ldc,
// K, alpha and beta are passed through as they are.
//
// NOTE(review): the swap presumably relies on a column-major matrix having
// the same memory image as the row-major transpose, so that
// C^T = op(B)^T * op(A)^T - confirm against the MatrixLayout definition.
void sgemm(
    MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
    const int M, const int N, const int K,
    const float alpha,
    const float * SF_BLAS_RESTRICT A, const int lda,
    const float * SF_BLAS_RESTRICT B, const int ldb,
    const float beta,
    float * SF_BLAS_RESTRICT C, const int ldc
)
{
    const bool is_row_major = (layout == MatrixLayout::RowMajor);

    if (!is_row_major)
    {
        // Operands, their strides, their transposition flags and the
        // output dimensions all trade places.
        sgemm_row_major(
            TransB, TransA,
            N, M, K,
            alpha,
            B, ldb,
            A, lda,
            beta,
            C, ldc
        );
        return;
    }

    sgemm_row_major(
        TransA, TransB,
        M, N, K,
        alpha,
        A, lda,
        B, ldb,
        beta,
        C, ldc
    );
}
|
||||
|
||||
std::vector<float> generate_random_matrix(int rows, int cols)
|
||||
{
|
||||
std::vector<float> m(rows * cols);
|
||||
|
||||
@@ -118,6 +118,16 @@ namespace Blas {
|
||||
float * SF_BLAS_RESTRICT C, const int ldc
|
||||
);
|
||||
|
||||
// Declaration of the sgemm overload that takes no ThreadPool argument.
//
//   layout          storage layout selector for the matrices
//   TransA, TransB  transposition flags for A and B
//   M, N, K         matrix dimensions
//   alpha, beta     scalar factors
//   A, lda          left input matrix and its stride
//   B, ldb          right input matrix and its stride
//   C, ldc          output matrix (non-const, written by the call) and its stride
//
// NOTE(review): presumably follows the usual BLAS sgemm contract
// (C = alpha * op(A) * op(B) + beta * C) - confirm against the definition.
void sgemm(
|
||||
MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
|
||||
const int M, const int N, const int K,
|
||||
const float alpha,
|
||||
const float * SF_BLAS_RESTRICT A, const int lda,
|
||||
const float * SF_BLAS_RESTRICT B, const int ldb,
|
||||
const float beta,
|
||||
float * SF_BLAS_RESTRICT C, const int ldc
|
||||
);
|
||||
|
||||
void test(
|
||||
ThreadPool& thread_pool
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user