Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 7f524cee authored by Mathieu Faverge's avatar Mathieu Faverge
Browse files

Fix Conjuguate Transpose on hermitian matrices

parent 2e02e6cb
No related branches found
No related tags found
No related merge requests found
......@@ -53,24 +53,26 @@ struct __spm_zmatvec_s {
spm_complex64_t *y;
spm_int_t incy;
__conj_fct_t conj_fct;
__conj_fct_t conjA_fct;
__conj_fct_t conjAt_fct;
__loop_fct_t loop_fct;
};
static inline int
__spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( col=0; col<n; col++, colptr++ )
......@@ -80,11 +82,11 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjAt_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
else {
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
......@@ -94,17 +96,18 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( col=0; col<n; col++, colptr++ )
......@@ -114,11 +117,11 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ];
}
else {
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
......@@ -128,17 +131,17 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
spm_int_t col, row, i;
if ( args->follow_x ) {
......@@ -147,7 +150,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
{
row = *rowptr - baseval;
y[ row * incy ] += alpha * conj_fct( *values ) * (*x);
y[ row * incy ] += alpha * conjA_fct( *values ) * (*x);
}
}
}
......@@ -157,7 +160,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
{
row = *rowptr - baseval;
*y += alpha * conj_fct( *values ) * x[ row * incx ];
*y += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
......@@ -167,17 +170,18 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
......@@ -186,11 +190,11 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
col = *colptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ];
}
else {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
}
}
return SPM_SUCCESS;
......@@ -199,17 +203,17 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
spm_int_t col, row, i;
for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
......@@ -217,7 +221,7 @@ __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
col = *colptr - baseval;
y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
}
return SPM_SUCCESS;
}
......@@ -245,6 +249,117 @@ __spm_zlascl( spm_complex64_t alpha,
#endif
static inline int
__spm_zmatvec_args_init( __spm_zmatvec_t *args,
spm_side_t side,
spm_trans_t transA,
spm_complex64_t alpha,
const spmatrix_t *A,
const spm_complex64_t *B,
spm_int_t ldb,
spm_complex64_t *C,
spm_int_t ldc )
{
spm_int_t incx, incy;
if ( side == SpmLeft ) {
incx = 1;
incy = 1;
}
else {
incx = ldb;
incy = ldc;
}
args->follow_x = 0;
args->baseval = spmFindBase( A );
args->n = A->n;
args->nnz = A->nnz;
args->alpha = alpha;
args->rowptr = A->rowptr;
args->colptr = A->colptr;
args->values = A->values;
args->x = B;
args->incx = incx;
args->y = C;
args->incy = incy;
args->conjA_fct = __fct_id;
args->conjAt_fct = __fct_id;
#if defined(PRECISION_c) || defined(PRECISION_z)
if ( A->mtxtype != SpmHermitian ) {
if ( transA == SpmConjTrans ) {
args->conjA_fct = __fct_conj;
args->conjAt_fct = __fct_conj;
}
}
else {
if ( transA == SpmTrans ) {
args->conjA_fct = __fct_conj;
args->conjAt_fct = __fct_id;
}
else {
args->conjA_fct = __fct_id;
args->conjAt_fct = __fct_conj;
}
}
#endif
args->loop_fct = NULL;
switch( A->fmttype ) {
case SpmCSC:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA == SpmNoTrans)) ||
((side == SpmRight) && (transA != SpmNoTrans)) )
{
args->follow_x = 1;
}
else {
args->follow_x = 0;
}
args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
}
break;
case SpmCSR:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
args->follow_x = 1;
}
else {
args->follow_x = 0;
}
args->colptr = A->rowptr;
args->rowptr = A->colptr;
args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
}
break;
case SpmIJV:
{
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
const __conj_fct_t tmp_fct = args->conjA_fct;
args->conjA_fct = args->conjAt_fct;
args->conjAt_fct = tmp_fct;
args->colptr = A->rowptr;
args->rowptr = A->colptr;
}
args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv;
}
break;
default:
return SPM_ERR_BADPARAMETER;
}
return SPM_SUCCESS;
}
/**
*******************************************************************************
*
......@@ -342,7 +457,7 @@ spm_zspmm( spm_side_t side,
spm_int_t ldc )
{
int rc = SPM_SUCCESS;
spm_int_t M, N, incx, incy, ldx, ldy, r;
spm_int_t M, N, ldx, ldy, r;
__spm_zmatvec_t args;
if ( transB != SpmNoTrans ) {
......@@ -355,8 +470,6 @@ spm_zspmm( spm_side_t side,
M = A->n;
N = K;
incx = 1;
incy = 1;
ldx = ldb;
ldy = ldc;
}
......@@ -364,8 +477,6 @@ spm_zspmm( spm_side_t side,
M = K;
N = A->n;
incx = ldb;
incy = ldc;
ldx = 1;
ldy = 1;
}
......@@ -381,76 +492,8 @@ spm_zspmm( spm_side_t side,
return SPM_SUCCESS;
}
{
args.follow_x = 0;
args.baseval = spmFindBase( A );
args.n = A->n;
args.nnz = A->nnz;
args.alpha = alpha;
args.rowptr = A->rowptr;
args.colptr = A->colptr;
args.values = A->values;
args.x = B;
args.incx = incx;
args.y = C;
args.incy = incy;
args.conj_fct = __fct_id;
args.loop_fct = NULL;
}
#if defined(PRECISION_c) || defined(PRECISION_z)
if ( ( transA == SpmConjTrans ) ||
( A->mtxtype == SpmHermitian ) )
{
args.conj_fct = __fct_conj;
}
#endif
switch( A->fmttype ) {
case SpmCSC:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA == SpmNoTrans)) ||
((side == SpmRight) && (transA != SpmNoTrans)) )
{
args.follow_x = 1;
}
else {
args.follow_x = 0;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
}
break;
case SpmCSR:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
args.follow_x = 1;
}
else {
args.follow_x = 0;
}
args.colptr = A->rowptr;
args.rowptr = A->colptr;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
}
break;
case SpmIJV:
{
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
args.colptr = A->rowptr;
args.rowptr = A->colptr;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv;
}
break;
default:
return SPM_ERR_BADPARAMETER;
}
__spm_zmatvec_args_init( &args, side, transA,
alpha, A, B, ldb, C, ldc );
for( r=0; (r < N) && (rc == SPM_SUCCESS); r++ ) {
args.x = B + r * ldx;
......@@ -528,59 +571,8 @@ spm_zspmv( spm_trans_t trans,
return SPM_SUCCESS;
}
{
args.follow_x = 0;
args.baseval = spmFindBase( A );
args.n = A->n;
args.nnz = A->nnz;
args.alpha = alpha;
args.rowptr = A->rowptr;
args.colptr = A->colptr;
args.values = A->values;
args.x = x;
args.incx = incx;
args.y = y;
args.incy = incy;
args.conj_fct = __fct_id;
args.loop_fct = NULL;
}
#if defined(PRECISION_c) || defined(PRECISION_z)
if ( ( trans == SpmConjTrans ) ||
( A->mtxtype == SpmHermitian ) )
{
args.conj_fct = __fct_conj;
}
#endif
switch( A->fmttype ) {
case SpmCSC:
{
args.follow_x = (trans == SpmNoTrans) ? 1 : 0;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
}
break;
case SpmCSR:
{
/* Switch pointers and side to get the correct behaviour */
args.follow_x = (trans == SpmNoTrans) ? 0 : 1;
args.colptr = A->rowptr;
args.rowptr = A->colptr;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
}
break;
case SpmIJV:
{
if ( trans != SpmNoTrans ) {
args.colptr = A->rowptr;
args.rowptr = A->colptr;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv;
}
break;
default:
return SPM_ERR_BADPARAMETER;
}
__spm_zmatvec_args_init( &args, SpmLeft, trans,
alpha, A, x, incx, y, incy );
rc = args.loop_fct( &args );
......
......@@ -97,10 +97,6 @@ int main (int argc, char **argv)
{
continue;
}
if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
{
continue;
}
printf(" Case %s - %s - %d - %s:\n",
mtxnames[mtxtype - SpmGeneral], fmtnames[fmttype - SpmCSC],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment