From 42ab5aee65c8c2c48af7a720aeaa5cbe9da63697 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Fri, 11 Nov 2016 22:51:36 +0100
Subject: [PATCH] Make expand working with variadic dofs

---
 z_spm_expand.c | 466 ++++++++++++++++++++++++++++---------------------
 1 file changed, 263 insertions(+), 203 deletions(-)

diff --git a/z_spm_expand.c b/z_spm_expand.c
index d9d982de..e2570aed 100644
--- a/z_spm_expand.c
+++ b/z_spm_expand.c
@@ -17,18 +17,14 @@
 #include "spm.h"
 #include "z_spm.h"
 
-/**
- * TODO: This function is incorrect
- */
 pastix_spm_t *
-z_spmExpand(const pastix_spm_t *spm)
+z_spmCSCExpand(const pastix_spm_t *spm)
 {
     pastix_spm_t       *newspm;
-    pastix_int_t        i, j, k, ii, jj, dof, dofi, dofj, col, row, baseval, cpt;
+    pastix_int_t        i, j, k, ii, jj, dofi, dofj, col, row, baseval, lda;
+    pastix_int_t        diag, height;
     pastix_int_t       *newcol, *newrow, *oldcol, *oldrow, *dofs;
-#if !defined(PRECISION_p)
-    pastix_complex64_t *newval, *oldval;
-#endif
+    pastix_complex64_t *newval, *oldval, *oldval2;
 
     if ( spm->dof == 1 ) {
         return (pastix_spm_t*)spm;
@@ -40,240 +36,289 @@ z_spmExpand(const pastix_spm_t *spm)
     }
 
     newspm = malloc( sizeof(pastix_spm_t) );
-    spmInit( newspm );
+    memcpy( newspm, spm, sizeof(pastix_spm_t) );
 
     baseval = spmFindBase( spm );
     oldcol = spm->colptr;
     oldrow = spm->rowptr;
     dofs   = spm->dofs;
 #if !defined(PRECISION_p)
-    oldval = (pastix_complex64_t*)(spm->values);
+    oldval = oldval2 = (pastix_complex64_t*)(spm->values);
 #endif
 
-    switch(spm->fmttype)
+    newspm->n = spm->nexp;
+    newspm->colptr = newcol = malloc(sizeof(pastix_int_t)*(spm->nexp+1));
+
+    /**
+     * First loop to compute the new colptr
+     */
+    *newcol = baseval;
+    for(j=0; j<spm->n; j++, oldcol++)
     {
-    case PastixCSC:
-        newspm->colptr = newcol = malloc(sizeof(pastix_int_t)*spm->nexp);
-        newspm->rowptr = newrow = malloc(sizeof(pastix_int_t)*spm->nnzexp);
+        diag = 0;
+        dofj = (spm->dof > 0 ) ? spm->dof : dofs[j+1] - dofs[j];
+
+        /* Sum the heights of the elements in the column */
+        newcol[1] = newcol[0];
+        for(k=oldcol[0]; k<oldcol[1]; k++)
+        {
+            i = oldrow[k-baseval] - baseval;
+            dofi = (spm->dof > 0 ) ? spm->dof : dofs[i+1] - dofs[i];
+            newcol[1] += dofi;
+
+            diag = (diag || (i == j));
+        }
+
+        diag = (diag & (spm->mtxtype != PastixGeneral));
+        height = newcol[1] - newcol[0];
+        newcol++;
+
+        /* Add extra columns */
+        for(jj=1; jj<dofj; jj++, newcol++)
+        {
+            newcol[1] = newcol[0] + height;
+
+            if ( diag ) {
+                newcol[1] -= jj;
+            }
+        }
+    }
+    assert( ((spm->mtxtype == PastixGeneral) && ((newcol[0]-baseval) == spm->nnzexp)) ||
+            ((spm->mtxtype != PastixGeneral) && ((newcol[0]-baseval) <= spm->nnzexp)) );
+
+    newspm->nnz = newcol[0] - baseval;
+    newspm->rowptr = newrow = malloc(sizeof(pastix_int_t)*newspm->nnz);
 #if !defined(PRECISION_p)
-        newspm->values  = newval = malloc(sizeof(pastix_complex64_t)*spm->nnzexp);
+    newspm->values = newval = malloc(sizeof(pastix_complex64_t)*newspm->nnz);
 #endif
 
-        *newcol = 0; newcol++;
-        *newrow = 0;
+    /**
+     * Second loop to compute the new rowptr and valptr
+     */
+    oldcol = spm->colptr;
+    oldrow = spm->rowptr;
+    newcol = newspm->colptr;
+    for(j=0, col=0; j<spm->n; j++, oldcol++)
+    {
         /**
-         * Loop on col
+         * Backup current position in oldval because we will pick
+         * interleaved data inside the buffer
          */
+        lda = newcol[1] - newcol[0];
+        oldval2 = oldval;
+
         if ( spm->dof > 0 ) {
-            col = 0;
-            dof = spm->dof;
+            dofj = spm->dof;
+            assert( col == spm->dof * j );
+        }
+        else {
+            dofj = dofs[j+1] - dofs[j];
+            assert( col == (dofs[j] - baseval) );
+        }
+
+        for(jj=0; jj<dofj; jj++, col++, newcol++)
+        {
+            assert( ((spm->mtxtype == PastixGeneral) && (lda == (newcol[1] - newcol[0]))) ||
+                    ((spm->mtxtype != PastixGeneral) && (lda >= (newcol[1] - newcol[0]))) );
+
+            /* Move to the top of the column jj in element (oldcol[j],j) */
+            oldval = oldval2;
 
-            for(i=0; i<spm->n; i++, col+=dof)
+            for(k=oldcol[0]; k<oldcol[1]; k++)
             {
-                for(ii=0; ii<dof; ii++, newcol++)
-                {
-                    /**
-                     * Loop on rows
-                     */
-                    for(k=oldcol[i]; k<oldcol[i+1]; k++, row+=dof)
-                    {
-                        j = oldrow[k-baseval]-baseval;
+                i = oldrow[k-baseval] - baseval;
 
-                        for(jj=0; jj<dof; jj++, newrow++)
-                        {
-                            (*newcol)++;
-                            (*newrow) = row + jj + baseval;
+                if ( spm->dof > 0 ) {
+                    dofi = spm->dof;
+                    row  = spm->dof * i;
+                }
+                else {
+                    dofi = dofs[i+1] - dofs[i];
+                    row  = dofs[i] - baseval;
+                }
+
+                /* Move to the top of the jj column in the current element */
+                oldval += dofi * jj;
 
+                for(ii=0; ii<dofi; ii++, row++)
+                {
+                    if ( (spm->mtxtype == PastixGeneral) ||
+                         (i != j) ||
+                         ((i == j) && (row >= col)) )
+                    {
+                        (*newrow) = row + baseval; newrow++;
 #if !defined(PRECISION_p)
-                            if ( (spm->mtxtype != PastixGeneral) &&
-                                 (row + jj < col + ii) )
-                            {
-                                (*newval) = oldval[ cpt ];
-                                newval++;
-                            }
-                            cpt++;
+                        (*newval) = *oldval; newval++;
 #endif
-                        }
                     }
-                    (*newcol) += baseval;
+                    oldval++;
                 }
+                /* Move to the top of the next element */
+                oldval += dofi * (dofj-jj-1);
             }
         }
+    }
+
+    newspm->gN      = newspm->n;
+    newspm->gnnz    = newspm->nnz;
+
+    newspm->gNexp   = newspm->gN;
+    newspm->nexp    = newspm->n;
+    newspm->gnnzexp = newspm->gnnz;
+    newspm->nnzexp  = newspm->nnz;
+
+    newspm->dof     = 1;
+    newspm->dofs    = NULL;
+    newspm->layout  = PastixColMajor;
+
+    assert(spm->loc2glob == NULL);//to do
+
+    (void)newval;
+    return newspm;
+}
+
+pastix_spm_t *
+z_spmCSRExpand(const pastix_spm_t *spm)
+{
+    pastix_spm_t       *newspm;
+    pastix_int_t        i, j, k, ii, jj, dofi, dofj, col, row, baseval, lda;
+    pastix_int_t        diag, height;
+    pastix_int_t       *newcol, *newrow, *oldcol, *oldrow, *dofs;
+    pastix_complex64_t *newval, *oldval, *oldval2;
+
+    if ( spm->dof == 1 ) {
+        return (pastix_spm_t*)spm;
+    }
+
+    if ( spm->layout != PastixColMajor ) {
+        pastix_error_print( "Unsupported layout\n" );
+        return NULL;
+    }
+
+    newspm = malloc( sizeof(pastix_spm_t) );
+    spmInit( newspm );
+
+    baseval = spmFindBase( spm );
+    oldcol = spm->colptr;
+    oldrow = spm->rowptr;
+    dofs   = spm->dofs;
+#if !defined(PRECISION_p)
+    oldval = oldval2 = (pastix_complex64_t*)(spm->values);
+#endif
+
+    newspm->n = spm->nexp;
+    newspm->rowptr = newrow = malloc(sizeof(pastix_int_t)*(spm->nexp+1));
+
+    /**
+     * First loop to compute the new rowptr
+     */
+    *newrow = baseval;
+    for(i=0; i<spm->n; i++, oldrow++)
+    {
+        diag = 0;
+        dofi = (spm->dof > 0 ) ? spm->dof : dofs[i+1] - dofs[i];
+
+        /* Sum the heights of the elements in the rowumn */
+        newrow[1] = newrow[0];
+        for(k=oldrow[0]; k<oldrow[1]; k++)
+        {
+            j = oldcol[k-baseval] - baseval;
+            dofj = (spm->dof > 0 ) ? spm->dof : dofs[j+1] - dofs[j];
+            newrow[1] += dofj;
+
+            diag = (diag || (i == j));
+        }
+
+        diag = (diag & (spm->mtxtype != PastixGeneral));
+        height = newrow[1] - newrow[0];
+        newrow++;
+
+        /* Add extra rowumns */
+        for(ii=1; ii<dofi; ii++, newrow++)
+        {
+            newrow[1] = newrow[0] + height;
+
+            if ( diag ) {
+                newrow[1] -= ii;
+            }
+        }
+    }
+    assert( ((spm->mtxtype == PastixGeneral) && ((newrow[0]-baseval) == spm->nnzexp)) ||
+            ((spm->mtxtype != PastixGeneral) && ((newrow[0]-baseval) <= spm->nnzexp)) );
+
+    newspm->nnz = newrow[0] - baseval;
+    newspm->colptr = newcol = malloc(sizeof(pastix_int_t)*newspm->nnz);
+#if !defined(PRECISION_p)
+    newspm->values = newval = malloc(sizeof(pastix_complex64_t)*newspm->nnz);
+#endif
+
+    /**
+     * Second loop to compute the new colptr and valptr
+     */
+    oldrow = spm->rowptr;
+    newrow = newspm->rowptr;
+    for(i=0, row=0; i<spm->n; i++, oldrow++)
+    {
+        /**
+         * Backup current position in oldval because we will pick
+         * interleaved data inside the buffer
+         */
+        lda = newrow[1] - newrow[0];
+        oldval2 = oldval;
+
+        if ( spm->dof > 0 ) {
+            dofi = spm->dof;
+            assert( row == spm->dof * i );
+        }
         else {
-            for(i=0; i<spm->n; i++)
+            dofi = dofs[i+1] - dofs[i];
+            assert( row == dofs[i] - baseval );
+        }
+
+        for(ii=0; ii<dofi; ii++, row++, newrow++)
+        {
+            assert( ((spm->mtxtype == PastixGeneral) && (lda == (newrow[1] - newrow[0]))) ||
+                    ((spm->mtxtype != PastixGeneral) && (lda >= (newrow[1] - newrow[0]))) );
+
+            /* Move to the top of the rowumn ii in element (oldrow[j],j) */
+            oldval = oldval2 + ii;
+
+            for(k=oldrow[0]; k<oldrow[1]; k++)
             {
-                col  = dofs[i];
-                dofi = dofs[i+1] - dofs[i];
+                j = oldcol[k-baseval] - baseval;
 
-                for(ii=0; ii<dofi; ii++, newcol++)
+                if ( spm->dof > 0 ) {
+                    dofj = spm->dof;
+                    col  = spm->dof * j;
+                }
+                else {
+                    dofj = dofs[j+1] - dofs[j];
+                    col  = dofs[j] - baseval;
+                }
+
+                for(jj=0; jj<dofj; jj++, col++)
                 {
-                    /**
-                     * Loop on rows
-                     */
-                    for(k=oldcol[i]; k<oldcol[i+1]; k++)
+                    if ( (spm->mtxtype == PastixGeneral) || ((i == j) && (col >= row)) )
                     {
-                        j = oldrow[k-baseval]-baseval;
-                        row  = dofs[j];
-                        dofj = dofs[j+1] - dofs[j];
-
-                        for(jj=0; jj<dofj; jj++, newrow++)
-                        {
-                            (*newcol)++;
-                            (*newrow) = row + jj + baseval;
-
+                        (*newcol) = col + baseval; newcol++;
 #if !defined(PRECISION_p)
-                            if ( (spm->mtxtype != PastixGeneral) &&
-                                 (row + jj < col + ii) )
-                            {
-                                (*newval) = oldval[ cpt ];
-                                newval++;
-                            }
-                            cpt++;
+                        (*newval) = *oldval; newval++;
 #endif
-                        }
                     }
-                    (*newcol) += baseval;
+                    oldval += dofi;
                 }
             }
         }
-    break;
-    case PastixCSR:
-    case PastixIJV:
-        free( newspm );
-        return NULL;
+        /* Move to the top of the next row of elements */
+        oldval -= (dofi-1);
     }
-        /*     for(i=0; i<spm->nexp; i++) */
-    /*     { */
-    /*         new_col[i+1]+=new_col[i]; */
-    /*     } */
-
-    /*     cpt = 0; */
-    /*     for(i=0; i < spm->n;i++) */
-    /*     { */
-    /*         col  = ( spm->dof > 0 ) ? i        : dofs[i]; */
-    /*         dofi = ( spm->dof > 0 ) ? spm->dof : dofs[i+1] - dofs[i]; */
-    /*         for(k=spm->colptr[i]-baseval ; k<spm->colptr[i+1]-baseval ;k++) */
-    /*         { */
-    /*             j = spm->rowptr[k] - baseval; */
-    /*             row  = ( spm->dof > 0 ) ? j        : dofs[j]; */
-    /*             dofj = ( spm->dof > 0 ) ? spm->dof : dofs[j+1] - dofs[j]; */
-    /*             for(ii=0;ii < dofi; ii++) */
-    /*             { */
-    /*                 for(jj=0;jj < dofj ; jj++) */
-    /*                 { */
-    /*                     new_vals[new_col[col+ii]] = vals[cpt]; */
-    /*                     new_row[new_col[col+ii]]  = row + jj + baseval; */
-    /*                     new_col[col+ii]++; */
-    /*                     cpt++; */
-    /*                 } */
-    /*             } */
-    /*         } */
-    /*     } */
-
-    /*     { */
-    /*         int tmp; */
-    /*         int tmp1 = 0; */
-    /*         for(i=0; i<spm->nexp; i++) */
-    /*         { */
-    /*             tmp = new_col[i]; */
-    /*             new_col[i] = tmp1+baseval; */
-    /*             tmp1 = tmp; */
-    /*         } */
-    /*         new_col[i] += baseval; */
-    /*     } */
-    /*     spm->gN   = spm->gNexp; */
-    /*     spm->n    = spm->nexp; */
-    /*     spm->gnnz = spm->gnnzexp; */
-    /*     spm->nnz  = spm->nnzexp; */
-
-    /*     spm->dof      = 1; */
-    /*     spm->dofs     = NULL; */
-    /*     spm->layout   = PastixColMajor; */
-
-    /*     spm->colptr   = new_col; */
-    /*     spm->rowptr   = new_row; */
-    /*     //spm->loc2glob = NULL; // ? */
-    /*     spm->values   = new_vals; */
-    /*     break; */
-
-    /* case PastixSymmetric: */
-    /*     for(i=0;i<spm->n ; i++) */
-    /*     { */
-    /*         col  = ( spm->dof > 0 ) ? i        : dofs[i]; */
-    /*         dofi = ( spm->dof > 0 ) ? spm->dof : dofs[i+1] - dofs[i]; */
-    /*         for(k=spm->colptr[i]-baseval; k<spm->colptr[i+1]-baseval; k++) */
-    /*         { */
-    /*             j = spm->rowptr[k]-baseval; */
-    /*             row  = ( spm->dof > 0 ) ? j        : dofs[j]; */
-    /*             dofj = ( spm->dof > 0 ) ? spm->dof : dofs[j+1] - dofs[j]; */
-    /*             for(ii=0; ii<dofi; ii++) */
-    /*             { */
-    /*                 for(jj=0; jj<dofj; jj++) */
-    /*                 { */
-    /*                     if( i != j ) */
-    /*                         new_col[col+ii+1] +=  1; */
-    /*                     else */
-    /*                         if(ii <= jj ) */
-    /*                             new_col[col+ii+1] += 1; */
-    /*                 } */
-    /*             } */
-    /*         } */
-    /*     } */
-    /*     for(i=0; i<spm->nexp; i++) */
-    /*     { */
-    /*         new_col[i+1] += new_col[i]; */
-    /*     } */
-    /*     pastix_int_t nnz = new_col[spm->nexp]; */
-    /*     new_row  = malloc(sizeof(pastix_int_t)*nnz); */
-    /*     new_vals = malloc(sizeof(pastix_complex64_t)*nnz); */
-
-    /*     cpt = 0; */
-    /*     for(i=0; i < spm->n;i++) */
-    /*     { */
-    /*         col  = ( spm->dof > 0 ) ? i        : dofs[i]; */
-    /*         dofi = ( spm->dof > 0 ) ? spm->dof : dofs[i+1] - dofs[i]; */
-    /*         for(k=spm->colptr[i]-baseval ; k<spm->colptr[i+1]-baseval ;k++) */
-    /*         { */
-    /*             j = spm->rowptr[k] - baseval; */
-    /*             row  = ( spm->dof > 0 ) ? j        : dofs[j]; */
-    /*             dofj = ( spm->dof > 0 ) ? spm->dof : dofs[j+1] - dofs[j]; */
-    /*             for(ii=0;ii < dofi; ii++) */
-    /*             { */
-    /*                 for(jj=0;jj < dofj ; jj++) */
-    /*                 { */
-    /*                     if( i == j ) */
-    /*                     { */
-    /*                         if ( ii <= jj ) */
-    /*                         { */
-    /*                             /\* diagonal dominant for spd matrix */
-    /*                             if( ii == jj) */
-    /*                                 new_vals[new_col[col+ii]] = 2*vals[cpt]; */
-    /*                              else */
-    /*                             *\/ */
-    /*                             new_vals[new_col[col+ii]] = vals[cpt]; */
-    /*                             new_row[new_col[col+ii]]  = row + jj + baseval; */
-    /*                             new_col[col+ii]++; */
-    /*                         } */
-    /*                     } */
-    /*                     else */
-    /*                     { */
-    /*                         new_vals[new_col[col+ii]] = vals[cpt]; */
-    /*                         new_row[new_col[col+ii]]  = row + jj + baseval; */
-    /*                         new_col[col+ii]++; */
-
-    /*                     } */
-    /*                     cpt++; */
-    /*                 } */
-    /*             } */
-            /* } */
-    /* } */
 
     newspm->gN      = spm->gNexp;
     newspm->n       = spm->nexp;
-    newspm->gnnz    = spm->gnnzexp;
-    newspm->nnz     = spm->nnzexp;
-    newspm->gnnzexp = spm->gnnzexp;
-    newspm->nnzexp  = spm->nnzexp;
+    newspm->gNexp   = newspm->gN;
+    newspm->nexp    = newspm->n;
+    newspm->gnnzexp = newspm->gnnz;
+    newspm->nnzexp  = newspm->nnz;
 
     newspm->dof      = 1;
     newspm->dofs     = NULL;
@@ -281,6 +326,21 @@ z_spmExpand(const pastix_spm_t *spm)
 
     assert(spm->loc2glob == NULL);//to do
 
-    (void)col; (void)cpt;
+    (void)newval;
     return newspm;
 }
+
+pastix_spm_t *
+z_spmExpand( const pastix_spm_t *spm )
+{
+    switch (spm->fmttype) {
+    case PastixCSC:
+        return z_spmCSCExpand( spm );
+    case PastixCSR:
+        return z_spmCSRExpand( spm );
+    case PastixIJV:
+        return NULL;//z_spmIJVExpand( spm );
+    }
+    return NULL;
+}
+
-- 
GitLab