Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,118 @@ void cgemm_mt_(const char* transa,
cluster_id);
} // cgemm that needn't malloc_ht or free_ht

void sgemv_mt_(const char* transa,
const int* m,
const int* n,
const float* alpha,
const float* a,
const int* lda,
const float* x,
const int* incx,
const float* beta,
float* y,
const int* incy,
int cluster_id)
{
mtblas_sgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
*alpha,
a,
*lda,
x,
*incx,
*beta,
y,
*incy,
cluster_id);
}

void dgemv_mt_(const char* transa,
const int* m,
const int* n,
const double* alpha,
const double* a,
const int* lda,
const double* x,
const int* incx,
const double* beta,
double* y,
const int* incy,
int cluster_id)
{
mtblas_dgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
*alpha,
a,
*lda,
x,
*incx,
*beta,
y,
*incy,
cluster_id);
}

void zgemv_mt_(const char* transa,
const int* m,
const int* n,
const std::complex<double>* alpha,
const std::complex<double>* a,
const int* lda,
const std::complex<double>* x,
const int* incx,
const std::complex<double>* beta,
std::complex<double>* y,
const int* incy,
int cluster_id)
{
mtblas_zgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
(const void*)alpha,
(const void*)a,
*lda,
(const void*)x,
*incx,
(const void*)beta,
(void*)y,
*incy,
cluster_id);
}

void cgemv_mt_(const char* transa,
const int* m,
const int* n,
const std::complex<float>* alpha,
const std::complex<float>* a,
const int* lda,
const std::complex<float>* x,
const int* incx,
const std::complex<float>* beta,
std::complex<float>* y,
const int* incy,
int cluster_id)
{
mtblas_cgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
(const void*)alpha,
(const void*)a,
*lda,
(const void*)x,
*incx,
(const void*)beta,
(void*)y,
*incy,
cluster_id);
}

// Used to replace original free

void sgemm_mth_(const char* transa,
Expand Down Expand Up @@ -330,4 +442,132 @@ void cgemm_mth_(const char* transa,
free_ht(alp);
free_ht(bet);
} // cgemm that needn't malloc_ht or free_ht

void sgemv_mth_(const char* transa,
const int* m,
const int* n,
const float* alpha,
const float* a,
const int* lda,
const float* x,
const int* incx,
const float* beta,
float* y,
const int* incy,
int cluster_id)
{
mt_hthread_sgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
*alpha,
a,
*lda,
x,
*incx,
*beta,
y,
*incy,
cluster_id);
}

void dgemv_mth_(const char* transa,
const int* m,
const int* n,
const double* alpha,
const double* a,
const int* lda,
const double* x,
const int* incx,
const double* beta,
double* y,
const int* incy,
int cluster_id)
{
mt_hthread_dgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
*alpha,
a,
*lda,
x,
*incx,
*beta,
y,
*incy,
cluster_id);
}

void zgemv_mth_(const char* transa,
const int* m,
const int* n,
const std::complex<double>* alpha,
const std::complex<double>* a,
const int* lda,
const std::complex<double>* x,
const int* incx,
const std::complex<double>* beta,
std::complex<double>* y,
const int* incy,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;

mt_hthread_zgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
(const void*)alp,
(const void*)a,
*lda,
(const void*)x,
*incx,
(const void*)bet,
(void*)y,
*incy,
cluster_id);

free_ht(alp);
free_ht(bet);
}

void cgemv_mth_(const char* transa,
const int* m,
const int* n,
const std::complex<float>* alpha,
const std::complex<float>* a,
const int* lda,
const std::complex<float>* x,
const int* incx,
const std::complex<float>* beta,
std::complex<float>* y,
const int* incy,
int cluster_id)
{
std::complex<float>* alp = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*alp = *alpha;
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
*bet = *beta;

mt_hthread_cgemv(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
*m,
*n,
(const void*)alp,
(const void*)a,
*lda,
(const void*)x,
*incx,
(const void*)bet,
(void*)y,
*incy,
cluster_id);

free_ht(alp);
free_ht(bet);
}
} // namespace mtfunc
104 changes: 104 additions & 0 deletions source/source_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,58 @@ void cgemm_mt_(const char* transa,
const int* ldc,
int cluster_id);

void sgemv_mt_(const char* transa,
const int* m,
const int* n,
const float* alpha,
const float* a,
const int* lda,
const float* x,
const int* incx,
const float* beta,
float* y,
const int* incy,
int cluster_id);

void dgemv_mt_(const char* transa,
const int* m,
const int* n,
const double* alpha,
const double* a,
const int* lda,
const double* x,
const int* incx,
const double* beta,
double* y,
const int* incy,
int cluster_id);

void zgemv_mt_(const char* transa,
const int* m,
const int* n,
const std::complex<double>* alpha,
const std::complex<double>* a,
const int* lda,
const std::complex<double>* x,
const int* incx,
const std::complex<double>* beta,
std::complex<double>* y,
const int* incy,
int cluster_id);

void cgemv_mt_(const char* transa,
const int* m,
const int* n,
const std::complex<float>* alpha,
const std::complex<float>* a,
const int* lda,
const std::complex<float>* x,
const int* incx,
const std::complex<float>* beta,
std::complex<float>* y,
const int* incy,
int cluster_id);

void sgemm_mth_(const char* transa,
const char* transb,
const int* m,
Expand Down Expand Up @@ -136,6 +188,58 @@ void cgemm_mth_(const char* transa,
const int* ldc,
int cluster_id);

void sgemv_mth_(const char* transa,
const int* m,
const int* n,
const float* alpha,
const float* a,
const int* lda,
const float* x,
const int* incx,
const float* beta,
float* y,
const int* incy,
int cluster_id);

void dgemv_mth_(const char* transa,
const int* m,
const int* n,
const double* alpha,
const double* a,
const int* lda,
const double* x,
const int* incx,
const double* beta,
double* y,
const int* incy,
int cluster_id);

void zgemv_mth_(const char* transa,
const int* m,
const int* n,
const std::complex<double>* alpha,
const std::complex<double>* a,
const int* lda,
const std::complex<double>* x,
const int* incx,
const std::complex<double>* beta,
std::complex<double>* y,
const int* incy,
int cluster_id);

void cgemv_mth_(const char* transa,
const int* m,
const int* n,
const std::complex<float>* alpha,
const std::complex<float>* a,
const int* lda,
const std::complex<float>* x,
const int* incx,
const std::complex<float>* beta,
std::complex<float>* y,
const int* incy,
int cluster_id);

// #define zgemm_ zgemm_mt

// The next is dsp utils. It may be moved to other files if this file get too huge
Expand Down
Loading
Loading