Mentions légales du service

Skip to content
Snippets Groups Projects
Commit df26aea1 authored by hhakim's avatar hhakim
Browse files

Add unit tests for MatBSR::transpose, conjugate, adjoint, getNonZeros,...

Add unit tests for MatBSR::transpose, conjugate, adjoint, getNonZeros, getNBytes, operator*=(FPP), get_col.
parent 3370290f
No related branches found
No related tags found
No related merge requests found
......@@ -130,6 +130,7 @@ void test_mul_sparse(const MatBSR<FPP, Cpu>& bmat)
cout << "=== Testing MatBSR::multiply(MatSparse, N)" << endl;
auto rmd = MatDense<FPP, Cpu>::randMat(bmat.getNbCol(), 32);
MatSparse<FPP, Cpu> rms(*rmd);
std::cout << "sp has nan:" << rms.containsNaN() << std::endl;
MatSparse<FPP, Cpu> rms_copy(rms);
auto dmat = bmat.to_dense();
dmat.multiply(rms, 'N');
......@@ -220,6 +221,7 @@ void test_gemm_NT(const MatBSR<FPP, Cpu>& bmat)
cout << "=== Testing MatBSR::faust_gemmNT" << endl;
auto B = MatDense<FPP, Cpu>::randMat(32, bmat.getNbCol());
auto C = MatDense<FPP, Cpu>::randMat(bmat.getNbRow(), 32);
std::cout << "ds has nan:" << B->containsNaN() << std::endl;
MatDense<FPP, Cpu> C_copy(*C);
auto dmat = bmat.to_dense();
bmat.faust_gemm(*B, *C, (*B)(0,0), (*C)(0,0), 'N', 'T');
......@@ -266,6 +268,109 @@ void test_gemm_HH(const MatBSR<FPP, Cpu>& bmat)
cout << "OK" << endl;
}
void test_transpose(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::transpose" << endl;
auto dmat = bmat.to_dense();
MatBSR<FPP, Cpu> bmat_t = bmat;
bmat.to_dense().print_file("bmat.txt");
bmat_t.transpose();
dmat.transpose();
assert(bmat_t.getNbRow() == dmat.getNbRow() && bmat_t.getNbCol() == dmat.getNbCol());
MatDense<FPP, Cpu> test = bmat_t.to_dense();
test -= dmat;
dmat.print_file("dmat.txt");
bmat_t.to_dense().print_file("bmat_t.txt");
assert(test.norm() < 1e-6);
cout << "OK" << endl;
}
void test_conjugate(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::conjugate" << endl;
auto dmat = bmat.to_dense();
MatBSR<FPP, Cpu> bmat_c = bmat;
bmat_c.conjugate();
dmat.conjugate();
assert(bmat_c.getNbRow() == dmat.getNbRow() && bmat_c.getNbCol() == dmat.getNbCol());
MatDense<FPP, Cpu> test = bmat_c.to_dense();
test -= dmat;
assert(test.norm() < 1e-6);
cout << "OK" << endl;
}
void test_adjoint(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::adjoint" << endl;
auto dmat = bmat.to_dense();
MatBSR<FPP, Cpu> bmat_a = bmat;
bmat_a.adjoint();
dmat.adjoint();
assert(bmat_a.getNbRow() == dmat.getNbRow() && bmat_a.getNbCol() == dmat.getNbCol());
MatDense<FPP, Cpu> test = bmat_a.to_dense();
test -= dmat;
assert(test.norm() < 1e-6);
cout << "OK" << endl;
}
void test_nnz(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::getNonZeros" << endl;
auto dmat = bmat.to_dense();
assert(bmat.getNonZeros() == dmat.getNonZeros());
cout << "OK" << endl;
}
void test_nbytes(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::getNBytes" << endl;
auto dmat = bmat.to_dense();
assert(bmat.getNBytes() == bmat.getNBlocks()*bmat.getNbBlockRow()*bmat.getNbBlockCol()*sizeof(FPP)+(bmat.getNbBlocksPerDim(0)+1+bmat.getNBlocks())*sizeof(int));
cout << "OK" << endl;
}
void test_get_type(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::getType" << endl;
assert(bmat.getType() == BSR);
cout << "OK" << endl;
}
void test_mul_scal(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::operator*=(FPP)" << endl;
auto rmd = MatDense<FPP, Cpu>::randMat(bmat.getNbCol(), 32);
MatBSR<FPP, Cpu> bmat_copy(bmat);
auto dmat = bmat.to_dense();
FPP scal = (*rmd)(0,0);
dmat *= scal;
bmat_copy *= scal;
MatDense<FPP, Cpu> test(dmat);
test -= bmat_copy.to_dense();
assert(test.norm() < 1e-6);
delete rmd;
cout << "OK" << endl;
}
void test_get_col(const MatBSR<FPP, Cpu>& bmat)
{
cout << "=== Testing MatBSR::get_col" << endl;
int n = bmat.getNbRow();
Vect<FPP, Cpu> bcol(n);
Vect<FPP, Cpu> dcol(n);
Vect<FPP, Cpu> test(n);
auto dmat = bmat.to_dense();
for(int j=0;j<bmat.getNbCol(); j++)
{
dcol = dmat.get_col(j);
bcol = bmat.get_col(j);
test = dcol;
test -= bcol;
assert(test.norm() < 1e-6);
}
cout << "OK" << endl;
}
int main(int argc, char** argv)
{
int m, n, bm, bn, bnnz;
......@@ -292,6 +397,14 @@ int main(int argc, char** argv)
test_gemm_NT(*bmat);
test_gemm_TT(*bmat);
test_gemm_HH(*bmat);
test_transpose(*bmat);
test_conjugate(*bmat);
test_adjoint(*bmat);
test_nnz(*bmat);
test_nbytes(*bmat);
test_get_type(*bmat);
test_mul_scal(*bmat);
test_get_col(*bmat);
delete bmat;
return EXIT_SUCCESS;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment