Mentions légales du service

Skip to content
Snippets Groups Projects

Fix Conjugate Transpose on Hermitian matrices

Merged Mathieu Faverge requested to merge faverge/spm:fix/spmv_conj into master
All threads resolved!
2 files
+ 188
200
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 188
196
@@ -53,24 +53,26 @@ struct __spm_zmatvec_s {
spm_complex64_t *y;
spm_int_t incy;
__conj_fct_t conj_fct;
__conj_fct_t conjA_fct;
__conj_fct_t conjAt_fct;
__loop_fct_t loop_fct;
};
static inline int
__spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( col=0; col<n; col++, colptr++ )
@@ -80,11 +82,11 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjAt_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
else {
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
@@ -94,17 +96,18 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( col=0; col<n; col++, colptr++ )
@@ -114,11 +117,11 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ];
}
else {
y[ col * incy ] += alpha * ( *values ) * x[ row * incx ];
y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
@@ -128,17 +131,17 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t n = args->n;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
spm_int_t col, row, i;
if ( args->follow_x ) {
@@ -147,7 +150,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
{
row = *rowptr - baseval;
y[ row * incy ] += alpha * conj_fct( *values ) * (*x);
y[ row * incy ] += alpha * conjA_fct( *values ) * (*x);
}
}
}
@@ -157,7 +160,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
{
row = *rowptr - baseval;
*y += alpha * conj_fct( *values ) * x[ row * incx ];
*y += alpha * conjA_fct( *values ) * x[ row * incx ];
}
}
}
@@ -167,17 +170,18 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
const __conj_fct_t conjAt_fct = args->conjAt_fct;
spm_int_t col, row, i;
for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
@@ -186,11 +190,11 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
col = *colptr - baseval;
if ( row != col ) {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ];
}
else {
y[ row * incy ] += alpha * ( *values ) * x[ col * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
}
}
return SPM_SUCCESS;
@@ -199,17 +203,17 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
static inline int
__spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
{
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conj_fct = args->conj_fct;
spm_int_t baseval = args->baseval;
spm_int_t nnz = args->nnz;
spm_complex64_t alpha = args->alpha;
const spm_int_t *rowptr = args->rowptr;
const spm_int_t *colptr = args->colptr;
const spm_complex64_t *values = args->values;
const spm_complex64_t *x = args->x;
spm_int_t incx = args->incx;
spm_complex64_t *y = args->y;
spm_int_t incy = args->incy;
const __conj_fct_t conjA_fct = args->conjA_fct;
spm_int_t col, row, i;
for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
@@ -217,7 +221,7 @@ __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
row = *rowptr - baseval;
col = *colptr - baseval;
y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ];
y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
}
return SPM_SUCCESS;
}
@@ -245,6 +249,117 @@ __spm_zlascl( spm_complex64_t alpha,
#endif
/**
 * @brief Fill the argument structure consumed by the matrix-vector kernels.
 *
 * Every field of @p args is set from the matrix and the dense operands, then
 * the conjugation callbacks and the kernel (loop_fct) are selected from the
 * matrix type, the storage format, the side and the transposition.
 *
 * @param[out] args    Structure to initialize.
 * @param[in]  side    SpmLeft forces unit increments on x and y; any other
 *                     value uses ldb and ldc as the increments.
 * @param[in]  transA  Transposition applied to A (SpmNoTrans, SpmTrans or
 *                     SpmConjTrans).
 * @param[in]  alpha   Scalar stored in args->alpha.
 * @param[in]  A       Sparse matrix providing n, nnz, the index arrays, the
 *                     values, the format and the matrix type.
 * @param[in]  B       Stored as args->x.
 * @param[in]  ldb     Increment on x when side is not SpmLeft; unused
 *                     otherwise.
 * @param[out] C       Stored as args->y.
 * @param[in]  ldc     Increment on y when side is not SpmLeft; unused
 *                     otherwise.
 *
 * @retval SPM_SUCCESS          when A->fmttype is SpmCSC, SpmCSR or SpmIJV.
 * @retval SPM_ERR_BADPARAMETER for any other format; args->loop_fct is then
 *                              left set to NULL.
 */
static inline int
__spm_zmatvec_args_init( __spm_zmatvec_t       *args,
                         spm_side_t             side,
                         spm_trans_t            transA,
                         spm_complex64_t        alpha,
                         const spmatrix_t      *A,
                         const spm_complex64_t *B,
                         spm_int_t              ldb,
                         spm_complex64_t       *C,
                         spm_int_t              ldc )
{
    /* (Left, NoTrans) or (Right, Trans/ConjTrans) */
    const int direct_access =
        ( (side == SpmLeft)  && (transA == SpmNoTrans) ) ||
        ( (side == SpmRight) && (transA != SpmNoTrans) );
    /* (Left, Trans/ConjTrans) or (Right, NoTrans) */
    const int swapped_access =
        ( (side == SpmLeft)  && (transA != SpmNoTrans) ) ||
        ( (side == SpmRight) && (transA == SpmNoTrans) );
    const int is_general = ( A->mtxtype == SpmGeneral );

    /* Matrix description */
    args->baseval = spmFindBase( A );
    args->n       = A->n;
    args->nnz     = A->nnz;
    args->rowptr  = A->rowptr;
    args->colptr  = A->colptr;
    args->values  = A->values;

    /* Dense operands: unit stride on the left side, leading dimensions otherwise */
    args->alpha = alpha;
    args->x     = B;
    args->y     = C;
    args->incx  = ( side == SpmLeft ) ? 1 : ldb;
    args->incy  = ( side == SpmLeft ) ? 1 : ldc;

    /* Defaults, possibly overridden below */
    args->follow_x   = 0;
    args->loop_fct   = NULL;
    args->conjA_fct  = __fct_id;
    args->conjAt_fct = __fct_id;

#if defined(PRECISION_c) || defined(PRECISION_z)
    if ( A->mtxtype == SpmHermitian ) {
        /*
         * Hermitian storage: exactly one of the stored entry and its mirrored
         * counterpart is conjugated. Which one depends on whether a plain
         * transpose is requested.
         */
        if ( transA == SpmTrans ) {
            args->conjA_fct = __fct_conj;
        }
        else {
            args->conjAt_fct = __fct_conj;
        }
    }
    else if ( transA == SpmConjTrans ) {
        /* General or symmetric storage: conjugate everything */
        args->conjA_fct  = __fct_conj;
        args->conjAt_fct = __fct_conj;
    }
#endif

    if ( A->fmttype == SpmCSC ) {
        args->follow_x = direct_access;
        args->loop_fct = is_general ? __spm_zmatvec_ge_csc
                                    : __spm_zmatvec_sy_csc;
    }
    else if ( A->fmttype == SpmCSR ) {
        /* CSR is handled as CSC with the two index arrays exchanged */
        args->follow_x = swapped_access;
        args->colptr   = A->rowptr;
        args->rowptr   = A->colptr;
        args->loop_fct = is_general ? __spm_zmatvec_ge_csc
                                    : __spm_zmatvec_sy_csr;
    }
    else if ( A->fmttype == SpmIJV ) {
        if ( swapped_access ) {
            /*
             * Rows and columns are exchanged, so the two conjugation
             * callbacks must be exchanged with them.
             */
            __conj_fct_t saved_fct = args->conjA_fct;

            args->conjA_fct  = args->conjAt_fct;
            args->conjAt_fct = saved_fct;
            args->colptr     = A->rowptr;
            args->rowptr     = A->colptr;
        }
        args->loop_fct = is_general ? __spm_zmatvec_ge_ijv
                                    : __spm_zmatvec_sy_ijv;
    }
    else {
        return SPM_ERR_BADPARAMETER;
    }

    return SPM_SUCCESS;
}
/**
*******************************************************************************
*
@@ -342,7 +457,7 @@ spm_zspmm( spm_side_t side,
spm_int_t ldc )
{
int rc = SPM_SUCCESS;
spm_int_t M, N, incx, incy, ldx, ldy, r;
spm_int_t M, N, ldx, ldy, r;
__spm_zmatvec_t args;
if ( transB != SpmNoTrans ) {
@@ -355,8 +470,6 @@ spm_zspmm( spm_side_t side,
M = A->n;
N = K;
incx = 1;
incy = 1;
ldx = ldb;
ldy = ldc;
}
@@ -364,8 +477,6 @@ spm_zspmm( spm_side_t side,
M = K;
N = A->n;
incx = ldb;
incy = ldc;
ldx = 1;
ldy = 1;
}
@@ -381,76 +492,8 @@ spm_zspmm( spm_side_t side,
return SPM_SUCCESS;
}
{
args.follow_x = 0;
args.baseval = spmFindBase( A );
args.n = A->n;
args.nnz = A->nnz;
args.alpha = alpha;
args.rowptr = A->rowptr;
args.colptr = A->colptr;
args.values = A->values;
args.x = B;
args.incx = incx;
args.y = C;
args.incy = incy;
args.conj_fct = __fct_id;
args.loop_fct = NULL;
}
#if defined(PRECISION_c) || defined(PRECISION_z)
if ( ( transA == SpmConjTrans ) ||
( A->mtxtype == SpmHermitian ) )
{
args.conj_fct = __fct_conj;
}
#endif
switch( A->fmttype ) {
case SpmCSC:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA == SpmNoTrans)) ||
((side == SpmRight) && (transA != SpmNoTrans)) )
{
args.follow_x = 1;
}
else {
args.follow_x = 0;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
}
break;
case SpmCSR:
{
/* Switch pointers and side to get the correct behaviour */
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
args.follow_x = 1;
}
else {
args.follow_x = 0;
}
args.colptr = A->rowptr;
args.rowptr = A->colptr;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
}
break;
case SpmIJV:
{
if ( ((side == SpmLeft) && (transA != SpmNoTrans)) ||
((side == SpmRight) && (transA == SpmNoTrans)) )
{
args.colptr = A->rowptr;
args.rowptr = A->colptr;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv;
}
break;
default:
return SPM_ERR_BADPARAMETER;
}
__spm_zmatvec_args_init( &args, side, transA,
alpha, A, B, ldb, C, ldc );
for( r=0; (r < N) && (rc == SPM_SUCCESS); r++ ) {
args.x = B + r * ldx;
@@ -528,59 +571,8 @@ spm_zspmv( spm_trans_t trans,
return SPM_SUCCESS;
}
{
args.follow_x = 0;
args.baseval = spmFindBase( A );
args.n = A->n;
args.nnz = A->nnz;
args.alpha = alpha;
args.rowptr = A->rowptr;
args.colptr = A->colptr;
args.values = A->values;
args.x = x;
args.incx = incx;
args.y = y;
args.incy = incy;
args.conj_fct = __fct_id;
args.loop_fct = NULL;
}
#if defined(PRECISION_c) || defined(PRECISION_z)
if ( ( trans == SpmConjTrans ) ||
( A->mtxtype == SpmHermitian ) )
{
args.conj_fct = __fct_conj;
}
#endif
switch( A->fmttype ) {
case SpmCSC:
{
args.follow_x = (trans == SpmNoTrans) ? 1 : 0;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
}
break;
case SpmCSR:
{
/* Switch pointers and side to get the correct behaviour */
args.follow_x = (trans == SpmNoTrans) ? 0 : 1;
args.colptr = A->rowptr;
args.rowptr = A->colptr;
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
}
break;
case SpmIJV:
{
if ( trans != SpmNoTrans ) {
args.colptr = A->rowptr;
args.rowptr = A->colptr;
}
args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv;
}
break;
default:
return SPM_ERR_BADPARAMETER;
}
__spm_zmatvec_args_init( &args, SpmLeft, trans,
alpha, A, x, incx, y, incy );
rc = args.loop_fct( &args );
Loading