diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index 3db4b47e8e..2baf73a4ec 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -187,6 +187,118 @@ void cgemm_mt_(const char* transa, cluster_id); } // cgemm that needn't malloc_ht or free_ht +void sgemv_mt_(const char* transa, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* x, + const int* incx, + const float* beta, + float* y, + const int* incy, + int cluster_id) +{ + mtblas_sgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy, + cluster_id); +} + +void dgemv_mt_(const char* transa, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* x, + const int* incx, + const double* beta, + double* y, + const int* incy, + int cluster_id) +{ + mtblas_dgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy, + cluster_id); +} + +void zgemv_mt_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id) +{ + mtblas_zgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)x, + *incx, + (const void*)beta, + (void*)y, + *incy, + cluster_id); +} + +void cgemv_mt_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id) +{ + mtblas_cgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + (const void*)alpha, + (const void*)a, + *lda, + (const void*)x, + *incx, + (const void*)beta, + (void*)y, + *incy, + cluster_id); +} + // Used to replace original free void sgemm_mth_(const char* transa, @@ -330,4 +442,132 @@ void cgemm_mth_(const char* transa, free_ht(alp); free_ht(bet); } // cgemm that needn't malloc_ht or free_ht + +void sgemv_mth_(const char* transa, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* x, + const int* incx, + const float* beta, + float* y, + const int* incy, + int cluster_id) +{ + mt_hthread_sgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy, + cluster_id); +} + +void dgemv_mth_(const char* transa, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* x, + const int* incx, + const double* beta, + double* y, + const int* incy, + int cluster_id) +{ + mt_hthread_dgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + *alpha, + a, + *lda, + x, + *incx, + *beta, + y, + *incy, + cluster_id); +} + +void zgemv_mth_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id) +{ + std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + + mt_hthread_zgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + (const void*)alp, + (const void*)a, + *lda, + (const void*)x, + *incx, + (const void*)bet, + (void*)y, + *incy, + cluster_id); + + free_ht(alp); + free_ht(bet); +} + +void cgemv_mth_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id) +{ + std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *alp = *alpha; + std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + *bet = *beta; + + mt_hthread_cgemv(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + *m, + *n, + (const void*)alp, + (const void*)a, + *lda, + (const void*)x, + *incx, + (const void*)bet, + (void*)y, + *incy, + cluster_id); + + free_ht(alp); + free_ht(bet); +} } // namespace mtfunc \ No newline at end of file diff --git a/source/source_base/kernels/dsp/dsp_connector.h b/source/source_base/kernels/dsp/dsp_connector.h index 34ccbaec4b..997a21de59 100644 --- a/source/source_base/kernels/dsp/dsp_connector.h +++ b/source/source_base/kernels/dsp/dsp_connector.h @@ -76,6 +76,58 @@ void cgemm_mt_(const char* transa, const int* ldc, int cluster_id); +void sgemv_mt_(const char* transa, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* x, + const int* incx, + const float* beta, + float* y, + const int* incy, + int cluster_id); + +void dgemv_mt_(const char* transa, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* x, + const int* incx, + const double* beta, + double* y, + const int* incy, + int cluster_id); + +void zgemv_mt_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id); + +void cgemv_mt_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id); + void sgemm_mth_(const char* transa, const char* transb, const int* m, @@ -136,6 +188,58 @@ void cgemm_mth_(const char* transa, const int* ldc, int cluster_id); +void sgemv_mth_(const char* transa, + const int* m, + const int* n, + const float* alpha, + const float* a, + const int* lda, + const float* x, + const int* incx, + const float* beta, + float* y, + const int* incy, + int cluster_id); + +void dgemv_mth_(const char* transa, + const int* m, + const int* n, + const double* alpha, + const double* a, + const int* lda, + const double* x, + const int* incx, + const double* beta, + double* y, + const int* incy, + int cluster_id); + +void zgemv_mth_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id); + +void cgemv_mth_(const char* transa, + const int* m, + const int* n, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* x, + const int* incx, + const std::complex* beta, + std::complex* y, + const int* incy, + int cluster_id); + // #define zgemm_ zgemm_mt // The next is dsp utils. It may be moved to other files if this file get too huge diff --git a/source/source_base/kernels/math_kernel_op.cpp b/source/source_base/kernels/math_kernel_op.cpp index aa5d365319..fa1bb4628c 100644 --- a/source/source_base/kernels/math_kernel_op.cpp +++ b/source/source_base/kernels/math_kernel_op.cpp @@ -48,6 +48,25 @@ struct gemm_op }; #ifdef __DSP +template +struct gemv_op_mt +{ + void operator()(const char& trans, + const int& m, + const int& n, + const T* alpha, + const T* A, + const int& lda, + const T* X, + const int& incx, + const T* beta, + T* Y, + const int& incy) + { + BlasConnector::gemv(trans, m, n, *alpha, A, lda, X, incx, *beta, Y, incy, base_device::AbacusDevice_t::DspDevice); + } +}; + template struct gemm_op_mt { @@ -163,6 +182,8 @@ template struct matrix_mul_vector_op, base_device::DEVICE_C template struct matrixTranspose_op; #endif #ifdef __DSP +template struct gemv_op_mt, base_device::DEVICE_CPU>; +template struct gemv_op_mt, base_device::DEVICE_CPU>; template struct gemm_op_mt, base_device::DEVICE_CPU>; template struct gemm_op_mt, base_device::DEVICE_CPU>; #endif diff --git a/source/source_base/kernels/math_kernel_op.h b/source/source_base/kernels/math_kernel_op.h index f5c8e218df..120844f3f9 100644 --- a/source/source_base/kernels/math_kernel_op.h +++ b/source/source_base/kernels/math_kernel_op.h @@ -241,6 +241,31 @@ template struct gemm_op { }; #ifdef __DSP +// compute Y = alpha * op(A) * X + beta * Y on DSP Hardware +template struct gemv_op_mt { + /// @brief Y = alpha * op(A) * X + beta * Y + /// + /// Input Parameters + /// \param trans : whether to transpose matrix A + /// \param m : row number of A + /// \param n : column number of A + /// \param alpha : input constant alpha + /// \param A : input matrix A + /// \param lda : leading dimension of A + /// \param X : input vector X + /// \param incx : increment of X + /// \param beta : input constant beta + /// \param Y : input vector Y + /// \param incy : increment of Y + /// + /// Output Parameters + /// \param Y : output vector Y + void operator()(const char &trans, const int &m, + const int &n, const T *alpha, const T *A, const int &lda, + const T *X, const int &incx, const T *beta, T *Y, + const int &incy); +}; + // compute C = alpha * op(A) * op(B) + beta * C on DSP Hardware template struct gemm_op_mt { /// @brief C = alpha * op(A) * op(B) + beta * C diff --git a/source/source_base/module_external/blas_connector_matrix.cpp b/source/source_base/module_external/blas_connector_matrix.cpp index 674a7c61d9..3b18d3ee3a 100644 --- a/source/source_base/module_external/blas_connector_matrix.cpp +++ b/source/source_base/module_external/blas_connector_matrix.cpp @@ -506,6 +506,22 @@ void BlasConnector::gemv(const char trans, const int m, const int n, if (device_type == base_device::AbacusDevice_t::CpuDevice) { sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + mtfunc::sgemv_mth_(&trans, + &m, + &n, + &alpha, + A, + &lda, + X, + &incx, + &beta, + Y, + &incy, + GlobalV::MY_RANK % PARAM.inp.dsp_count); + } +#endif #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); @@ -524,6 +540,22 @@ void BlasConnector::gemv(const char trans, const int m, const int n, if (device_type == base_device::AbacusDevice_t::CpuDevice) { dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + mtfunc::dgemv_mth_(&trans, + &m, + &n, + &alpha, + A, + &lda, + X, + &incx, + &beta, + Y, + &incy, + GlobalV::MY_RANK % PARAM.inp.dsp_count); + } +#endif #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op"); @@ -542,6 +574,22 @@ void BlasConnector::gemv(const char trans, const int m, const int n, if (device_type == base_device::AbacusDevice_t::CpuDevice) { cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + mtfunc::cgemv_mth_(&trans, + &m, + &n, + &alpha, + A, + &lda, + X, + &incx, + &beta, + Y, + &incy, + GlobalV::MY_RANK % PARAM.inp.dsp_count); + } +#endif #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag()); @@ -562,6 +610,22 @@ void BlasConnector::gemv(const char trans, const int m, const int n, if (device_type == base_device::AbacusDevice_t::CpuDevice) { zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy); } +#ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + mtfunc::zgemv_mth_(&trans, + &m, + &n, + &alpha, + A, + &lda, + X, + &incx, + &beta, + Y, + &incy, + GlobalV::MY_RANK % PARAM.inp.dsp_count); + } +#endif #ifdef __CUDA else if (device_type == base_device::AbacusDevice_t::GpuDevice) { cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag());