Add single threaded sgemm.

This commit is contained in:
Tomasz Sobczyk
2020-11-22 18:27:51 +01:00
committed by nodchip
parent ef4fdb40f9
commit 4ea8572b6d
2 changed files with 300 additions and 0 deletions

View File

@@ -546,6 +546,156 @@ namespace Blas {
);
thread_pool.wait_for_workers_finished();
#endif
}
void sgemm_row_major_transpose_right(
const int M, const int N, const int K,
const float alpha,
const float * SF_BLAS_RESTRICT A, const int lda,
const float * SF_BLAS_RESTRICT B, const int ldb,
const float beta,
float * SF_BLAS_RESTRICT C, const int ldc
)
{
#if defined(USE_SSE3)
const __m128 alpha4 = _mm_set1_ps(alpha);
const __m128 beta4 = _mm_set1_ps(beta);
for (int m = 0; m < M - 1; m += 2)
{
int n = 0;
for (; n < N - 3; n += 4)
{
// mn
__m128 sum00 = _mm_setzero_ps();
__m128 sum01 = _mm_setzero_ps();
__m128 sum02 = _mm_setzero_ps();
__m128 sum03 = _mm_setzero_ps();
__m128 sum10 = _mm_setzero_ps();
__m128 sum11 = _mm_setzero_ps();
__m128 sum12 = _mm_setzero_ps();
__m128 sum13 = _mm_setzero_ps();
// Horizontal sum of elements in sum[m][n] corresponds to
// the final element in the C.
int k = 0;
for (; k < K - 3; k += 4)
{
const __m128 a0 = _mm_loadu_ps(&A[(m+0)*lda+k+0]);
const __m128 a1 = _mm_loadu_ps(&A[(m+1)*lda+k+0]);
const __m128 b0 = _mm_loadu_ps(&B[(n+0)*ldb+k+0]);
const __m128 b1 = _mm_loadu_ps(&B[(n+1)*ldb+k+0]);
const __m128 b2 = _mm_loadu_ps(&B[(n+2)*ldb+k+0]);
const __m128 b3 = _mm_loadu_ps(&B[(n+3)*ldb+k+0]);
sum00 = _mm_add_ps(sum00, _mm_mul_ps(a0, b0));
sum01 = _mm_add_ps(sum01, _mm_mul_ps(a0, b1));
sum02 = _mm_add_ps(sum02, _mm_mul_ps(a0, b2));
sum03 = _mm_add_ps(sum03, _mm_mul_ps(a0, b3));
sum10 = _mm_add_ps(sum10, _mm_mul_ps(a1, b0));
sum11 = _mm_add_ps(sum11, _mm_mul_ps(a1, b1));
sum12 = _mm_add_ps(sum12, _mm_mul_ps(a1, b2));
sum13 = _mm_add_ps(sum13, _mm_mul_ps(a1, b3));
}
for(; k < K; k += 1)
{
const float a0 = A[(m+0)*lda+k+0];
const float a1 = A[(m+1)*lda+k+0];
const float b0 = B[(n+0)*ldb+k+0];
const float b1 = B[(n+1)*ldb+k+0];
const float b2 = B[(n+2)*ldb+k+0];
const float b3 = B[(n+3)*ldb+k+0];
// Since all will be summed vertically anyway we can
// just add to the first element.
// Other elements are left unmodified.
sum00 = _mm_add_ss(sum00, _mm_set_ss(a0 * b0));
sum01 = _mm_add_ss(sum01, _mm_set_ss(a0 * b1));
sum02 = _mm_add_ss(sum02, _mm_set_ss(a0 * b2));
sum03 = _mm_add_ss(sum03, _mm_set_ss(a0 * b3));
sum10 = _mm_add_ss(sum10, _mm_set_ss(a1 * b0));
sum11 = _mm_add_ss(sum11, _mm_set_ss(a1 * b1));
sum12 = _mm_add_ss(sum12, _mm_set_ss(a1 * b2));
sum13 = _mm_add_ss(sum13, _mm_set_ss(a1 * b3));
}
__m128 s0 = m128_hadd_ps(sum00, sum01, sum02, sum03);
__m128 s1 = m128_hadd_ps(sum10, sum11, sum12, sum13);
s0 = _mm_mul_ps(s0, alpha4);
s1 = _mm_mul_ps(s1, alpha4);
__m128 c0 = _mm_loadu_ps(&C[(m+0)*ldc+(n+0)]);
__m128 c1 = _mm_loadu_ps(&C[(m+1)*ldc+(n+0)]);
c0 = _mm_mul_ps(c0, beta4);
c1 = _mm_mul_ps(c1, beta4);
c0 = _mm_add_ps(c0, s0);
c1 = _mm_add_ps(c1, s1);
_mm_storeu_ps(&C[(m+0)*ldc+(n+0)], c0);
_mm_storeu_ps(&C[(m+1)*ldc+(n+0)], c1);
}
for(; n < N; n += 1)
{
float sum0 = 0.0f;
float sum1 = 0.0f;
for (int k = 0; k < K; ++k)
{
const float a0 = A[(m+0)*lda+k+0];
const float a1 = A[(m+1)*lda+k+0];
const float b0 = B[(n+0)*ldb+k+0];
sum0 += a0 * b0;
sum1 += a1 * b0;
}
C[(m+0)*ldc+(n+0)] = C[(m+0)*ldc+(n+0)] * beta + sum0 * alpha;
C[(m+1)*ldc+(n+0)] = C[(m+1)*ldc+(n+0)] * beta + sum1 * alpha;
}
}
for (; m < M; m += 1)
{
for (int n = 0; n < N; n += 1)
{
float sum = 0.0f;
for (int k = 0; k < K; k += 1)
{
sum += A[m*lda + k] * B[n*ldb + k];
}
C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha;
}
}
#else
for (int m = 0; m < M; m += 1)
{
for (int n = 0; n < N; n += 1)
{
float sum = 0.0f;
for (int k = 0; k < K; k += 1)
{
sum += A[m*lda + k] * B[n*ldb + k];
}
C[m*ldc + n] = C[m*ldc + n] * beta + sum * alpha;
}
}
#endif
}
@@ -605,6 +755,35 @@ namespace Blas {
);
}
void sgemm_row_major_transpose_none(
const int M, const int N, const int K,
const float alpha,
const float * SF_BLAS_RESTRICT A, const int lda,
const float * SF_BLAS_RESTRICT B, const int ldb,
const float beta,
float * SF_BLAS_RESTRICT C, const int ldc
)
{
constexpr static int temporary_buffer_index = 1;
auto B_tr = get_thread_local_temporary_storage(K * N, temporary_buffer_index);
transpose(
K, N,
B, ldb,
B_tr, K
);
sgemm_row_major_transpose_right(
M, N, K,
alpha,
A, lda,
B_tr, K,
beta,
C, ldc
);
}
void sgemm_row_major(
ThreadPool& thread_pool,
MatrixTranspose TransA, MatrixTranspose TransB,
@@ -684,6 +863,80 @@ namespace Blas {
}
}
void sgemm_row_major(
MatrixTranspose TransA, MatrixTranspose TransB,
const int M, const int N, const int K,
const float alpha,
const float * SF_BLAS_RESTRICT A, const int lda,
const float * SF_BLAS_RESTRICT B, const int ldb,
const float beta,
float * SF_BLAS_RESTRICT C, const int ldc
)
{
constexpr static int temporary_buffer_index = 0;
if (TransA == MatrixTranspose::Trans && TransB == MatrixTranspose::Trans)
{
auto A_tr = get_thread_local_temporary_storage(K * M, temporary_buffer_index);
transpose(
K, M,
A, lda,
A_tr, K
);
sgemm_row_major_transpose_right(
M, N, K,
alpha,
A_tr, K,
B, ldb,
beta,
C, ldc
);
}
else if (TransA == MatrixTranspose::NoTrans && TransB == MatrixTranspose::Trans)
{
sgemm_row_major_transpose_right(
M, N, K,
alpha,
A, lda,
B, ldb,
beta,
C, ldc
);
}
else if (TransA == MatrixTranspose::Trans && TransB == MatrixTranspose::NoTrans)
{
auto A_tr = get_thread_local_temporary_storage(K * M, temporary_buffer_index);
transpose(
K, M,
A, lda,
A_tr, K
);
sgemm_row_major_transpose_none(
M, N, K,
alpha,
A_tr, K,
B, ldb,
beta,
C, ldc
);
}
else // no transpositions
{
sgemm_row_major_transpose_none(
M, N, K,
alpha,
A, lda,
B, ldb,
beta,
C, ldc
);
}
}
void sgemm(
ThreadPool& thread_pool,
MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
@@ -723,6 +976,43 @@ namespace Blas {
}
}
void sgemm(
MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
const int M, const int N, const int K,
const float alpha,
const float * SF_BLAS_RESTRICT A, const int lda,
const float * SF_BLAS_RESTRICT B, const int ldb,
const float beta,
float * SF_BLAS_RESTRICT C, const int ldc
)
{
if (layout == MatrixLayout::RowMajor)
{
sgemm_row_major(
TransA, TransB,
M, N, K,
alpha,
A, lda,
B, ldb,
beta,
C, ldc
);
}
else
{
sgemm_row_major(
TransB, TransA,
N, M, K,
alpha,
B, ldb,
A, lda,
beta,
C, ldc
);
}
}
std::vector<float> generate_random_matrix(int rows, int cols)
{
std::vector<float> m(rows * cols);

View File

@@ -118,6 +118,16 @@ namespace Blas {
float * SF_BLAS_RESTRICT C, const int ldc
);
void sgemm(
MatrixLayout layout, MatrixTranspose TransA, MatrixTranspose TransB,
const int M, const int N, const int K,
const float alpha,
const float * SF_BLAS_RESTRICT A, const int lda,
const float * SF_BLAS_RESTRICT B, const int ldb,
const float beta,
float * SF_BLAS_RESTRICT C, const int ldc
);
void test(
ThreadPool& thread_pool
);