Mentions légales du service

Skip to content
Snippets Groups Projects
Commit bcd227fa authored by hhakim's avatar hhakim
Browse files

Add GPU batched_SVD updating to gpu_mod@4bfc92d8

parent 5f252c1f
Branches
Tags
No related merge requests found
Subproject commit a33c81e63bc4b62498b66587d2c83824b9167594
Subproject commit 4bfc92d8060fb25720e42b012136755f7e14a1aa
......@@ -767,7 +767,6 @@ namespace Faust
dsm_funcs->free(real_mat.gpu_mat);
real_mat.gpu_mat = real_gpu_mat;
}
template<>
void butterfly_diag_prod(MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>& X, const Vect<@FAUST_SCALAR_FOR_GM@, GPU2>& d1, const Vect<@FAUST_SCALAR_FOR_GM@, GPU2>& d2, const int* ids)
......@@ -775,4 +774,11 @@ namespace Faust
auto dsm_funcs = GPUModHandler::get_singleton()->dsm_funcs(@FAUST_SCALAR_FOR_GM@(0));
dsm_funcs->butterfly_diag_prod(X.gpu_mat, d1.gpu_mat, d2.gpu_mat, ids);
}
/**
 * \brief GPU2 specialization of batched_svd for the scalar type this file is generated for.
 *
 * Thin forwarding wrapper: it looks up the dense-matrix function table of the
 * gpu_mod backend for the current scalar type and hands the raw GPU handles of
 * the arguments to the backend's batched_svd.
 *
 * \param As input matrices handle (NOTE(review): presumably the batch of matrices to decompose; layout is defined by gpu_mod -- confirm).
 * \param nbatches forwarded unchanged to the backend (presumably the number of matrices in the batch -- confirm).
 * \param Us, Vs, Ss output handles forwarded to the backend (presumably singular vectors / singular values -- confirm).
 * \param rank forwarded unchanged to the backend; the declaration gives it a default of 0 (meaning of 0 is decided by gpu_mod -- confirm).
 */
template<>
void batched_svd(MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>& As, const uint32_t nbatches, MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>& Us, MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>& Vs, MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>& Ss, const uint32_t rank /*= 0*/)
{
	// The scalar-typed dummy argument selects the function table matching this specialization.
	auto&& gm_handler = GPUModHandler::get_singleton();
	auto dense_api = gm_handler->dsm_funcs(@FAUST_SCALAR_FOR_GM@(0));
	// Only the underlying gpu_mod handles cross the library boundary.
	dense_api->batched_svd(
			As.gpu_mat, nbatches,
			Us.gpu_mat, Vs.gpu_mat, Ss.gpu_mat,
			rank);
}
}
......@@ -19,6 +19,10 @@ namespace Faust
/**
 * \brief Forward declaration of the free function template butterfly_diag_prod for GPU2 dense matrices.
 *
 * Declared ahead of MatDense<FPP, GPU2> so that the class can name it in a friend declaration.
 *
 * \param X dense GPU matrix, taken by non-const reference (NOTE(review): presumably updated in place -- confirm against the specialization).
 * \param d1 first GPU vector operand (read-only).
 * \param d2 second GPU vector operand (read-only).
 * \param ids read-only array of integer indices (NOTE(review): expected length and memory space are not visible here -- confirm).
 */
template<typename FPP>
void butterfly_diag_prod(MatDense<FPP, GPU2>& X, const Vect<FPP, GPU2>& d1, const Vect<FPP, GPU2>& d2, const int* ids);
/**
 * \brief Forward declaration of the free function template batched_svd for GPU2 dense matrices.
 *
 * Declared ahead of MatDense<FPP, GPU2> so that the class can name it in a friend declaration.
 * The default value of rank is fixed here, on the first declaration.
 *
 * \param As dense GPU matrix holding the input (NOTE(review): presumably the batch of matrices to decompose -- confirm layout).
 * \param nbatches unsigned count passed along with As (presumably the number of matrices in the batch -- confirm).
 * \param Us, Vs, Ss dense GPU matrices taken by non-const reference (presumably receive the SVD factors -- confirm).
 * \param rank optional, defaults to 0 (NOTE(review): meaning of 0 is defined by the implementation -- confirm).
 */
template<typename FPP>
void batched_svd(MatDense<FPP, GPU2>& As, const uint32_t nbatches, MatDense<FPP, GPU2>& Us, MatDense<FPP, GPU2>& Vs, MatDense<FPP, GPU2>& Ss, const uint32_t rank = 0);
template<typename FPP>
class MatDense<FPP, GPU2> : public MatGeneric<FPP,GPU2>
{
......@@ -27,6 +31,7 @@ namespace Faust
friend MatBSR<FPP,GPU2>;
friend MatDense<std::complex<double>,GPU2>; // TODO limit to real function
// The free-function specializations below are befriended so they can use this class's members directly
// (NOTE(review): presumably needed to reach the gpu_mat handle -- confirm its access level).
friend void butterfly_diag_prod<>(MatDense<FPP, GPU2>& X, const Vect<FPP, GPU2>& d1, const Vect<FPP, GPU2>& d2, const int* ids);
// The default argument of rank is only repeated as a comment here; it is specified on the function template's own declaration.
friend void batched_svd<>(MatDense<FPP, GPU2>& As, const uint32_t nbatches, MatDense<FPP, GPU2>& Us, MatDense<FPP, GPU2>& Vs, MatDense<FPP, GPU2>& Ss, const uint32_t rank /*= 0*/);
// friend void gemm<>(const MatDense<FPP, GPU2> &A, const MatDense<FPP, GPU2> &B, MatDense<FPP, GPU2> &C, const FPP& alpha, const FPP& beta, const char opA, const char opB);
//
// friend void gemv<>(const MatDense<FPP, GPU2> &A, const Vect<FPP, GPU2> &B, Vect<FPP, GPU2> &C, const FPP& alpha, const FPP& beta, const char opA, const char opB);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment