From 7215afabd9508e66df6bf54b7ead787f6e7a06dc Mon Sep 17 00:00:00 2001
From: Raphael Boucherie <raphael.boucherie@inria.fr>
Date: Mon, 24 Apr 2017 18:14:21 +0200
Subject: [PATCH] all tests works for systolic and hqr

---
 include/libhqr.h          |  17 +++++
 src/libhqr.c              | 153 ++++++++++++++++++++------------------
 src/libhqr_systolic.c     |  40 +++++++---
 testings/testing_pivgen.c |   2 +-
 4 files changed, 128 insertions(+), 84 deletions(-)

diff --git a/include/libhqr.h b/include/libhqr.h
index 03056e2..9acbd02 100644
--- a/include/libhqr.h
+++ b/include/libhqr.h
@@ -89,6 +89,16 @@ typedef struct libhqr_tiledesc_s{
     int p;     /**< The number of nodes per column in the data distribution */
 } libhqr_tiledesc_t;
 
+
+typedef struct libhqr_tileinfo_s{
+    int type;
+    int currpiv;
+    int nextpiv;
+    int prevpiv;
+    int first_nextpiv;
+    int first_prevpiv;
+} libhqr_tileinfo_t;
+
 struct libhqr_tree_s;
 typedef struct libhqr_tree_s libhqr_tree_t;
 
@@ -190,6 +200,13 @@ int  libhqr_hqr_init( libhqr_tree_t *qrtree,
                       int type_llvl, int type_hlvl,
                       int a, int p, int domino, int tsrr );
 void libhqr_hqr_finalize( libhqr_tree_t *qrtree );
+
+
+void libhqr_matrix_init(libhqr_tree_t *qrtree, const libhqr_tree_t *qrtree_init);
+int rdmtx_gettype(const libhqr_tree_t *qrtree, int k, int m);
+int rdmtx_currpiv(const libhqr_tree_t *qrtree, int k, int m);
+int rdmtx_nextpiv(const libhqr_tree_t *qrtree, int k, int p, int m);
+int rdmtx_prevpiv(const libhqr_tree_t *qrtree, int k, int p, int m);
 void libhqr_matrix_finalize(libhqr_tree_t *qrtree);
 
 /*
diff --git a/src/libhqr.c b/src/libhqr.c
index f1442b4..2750739 100644
--- a/src/libhqr.c
+++ b/src/libhqr.c
@@ -108,15 +108,6 @@ struct hqr_args_s {
     int *perm;
 };
 
-typedef struct libhqr_tileinfo_s{
-    int type;
-    int currpiv;
-    int nextpiv;
-    int prevpiv;
-    int first_nextpiv;
-    int first_prevpiv;
-} libhqr_tileinfo_t;
-
 struct hqr_subpiv_s {
     /**
      * currpiv
@@ -179,14 +170,6 @@ static void hqr_low_greedy_init(   hqr_subpiv_t *arg, int minMN);
 static void hqr_low_binary_init(   hqr_subpiv_t *arg);
 static void hqr_low_fibonacci_init(hqr_subpiv_t *arg, int minMN);
 
-/* Stocking matrix info */
-void libhqr_matrix_init(libhqr_tree_t *qrtree, const libhqr_tree_t *qrtree_init);
-
-/* Function for getting the info on the matrix*/
-static int rdmtx_gettype(const libhqr_tree_t *qrtree, int k, int m);
-static int rdmtx_currpiv(const libhqr_tree_t *qrtree, int k, int m);
-static int rdmtx_nextpiv(const libhqr_tree_t *qrtree, int k, int p, int m);
-static int rdmtx_prevpiv(const libhqr_tree_t *qrtree, int k, int p, int m);
 
 /****************************************************
  * Reading functions
@@ -198,14 +181,14 @@ static int rdmtx_prevpiv(const libhqr_tree_t *qrtree, int k, int p, int m);
  *    m      - line anhilated
  */
 
-static int rdmtx_gettype(const libhqr_tree_t *qrtree, int k, int m){
+int  rdmtx_gettype(const libhqr_tree_t *qrtree, int k, int m){
     int id;
     libhqr_tileinfo_t *arg = (libhqr_tileinfo_t*)(qrtree->args);
     id = (k * qrtree->mt) + m;
     return arg[id].type;
 }
 
-static int rdmtx_currpiv(const libhqr_tree_t *qrtree, int k, int m){
+int rdmtx_currpiv(const libhqr_tree_t *qrtree, int k, int m){
     int id, perm_m, p, a;
     libhqr_tileinfo_t *arg = (libhqr_tileinfo_t*)(qrtree->args);
     perm_m = m;
@@ -222,7 +205,7 @@ static int rdmtx_currpiv(const libhqr_tree_t *qrtree, int k, int m){
  *    p - line used as pivot
  */
 
-static int rdmtx_nextpiv(const libhqr_tree_t *qrtree, int k, int p, int m){
+int rdmtx_nextpiv(const libhqr_tree_t *qrtree, int k, int p, int m){
     int id;
     libhqr_tileinfo_t *arg = (libhqr_tileinfo_t*)(qrtree->args);
     int gmt = qrtree->mt;
@@ -238,7 +221,7 @@ static int rdmtx_nextpiv(const libhqr_tree_t *qrtree, int k, int p, int m){
     }
 }
 
-static int rdmtx_prevpiv(const libhqr_tree_t *qrtree, int k, int p, int m){
+int rdmtx_prevpiv(const libhqr_tree_t *qrtree, int k, int p, int m){
     int id;
     libhqr_tileinfo_t *arg = (libhqr_tileinfo_t*)(qrtree->args);
     int gmt = qrtree->mt;
@@ -2732,6 +2715,7 @@ libhqr_svd_init( libhqr_tree_t *qrtree,
 {
     int low_mt, minMN, a = -1;
     hqr_args_t *arg;
+    libhqr_tree_t qrtree_init;
 
     if (qrtree == NULL) {
         fprintf(stderr,"libhqr_svd_init, illegal value of qrtree");
@@ -2750,72 +2734,93 @@ libhqr_svd_init( libhqr_tree_t *qrtree,
     /* Compute parameters */
     p = libhqr_imax( p, 1 );
 
-    qrtree->getnbgeqrf = svd_getnbgeqrf;
-    qrtree->getm       = svd_getm;
-    qrtree->geti       = svd_geti;
-    qrtree->gettype    = svd_gettype;
-    qrtree->currpiv    = svd_currpiv;
-    qrtree->nextpiv    = svd_nextpiv;
-    qrtree->prevpiv    = svd_prevpiv;
+    /* Create a temporary qrtree structure based on the functions */
+    {
+        qrtree_init.getnbgeqrf = svd_getnbgeqrf;
+        qrtree_init.getm       = svd_getm;
+        qrtree_init.geti       = svd_geti;
+        qrtree_init.gettype    = svd_gettype;
+        qrtree_init.currpiv    = svd_currpiv;
+        qrtree_init.nextpiv    = svd_nextpiv;
+        qrtree_init.prevpiv    = svd_prevpiv;
 
-    qrtree->mt   = (trans == LIBHQR_QR) ? A->mt : A->nt;
-    qrtree->nt   = (trans == LIBHQR_QR) ? A->nt : A->mt;
+        qrtree_init.mt   = (trans == LIBHQR_QR) ? A->mt : A->nt;
+        qrtree_init.nt   = (trans == LIBHQR_QR) ? A->nt : A->mt;
 
-    qrtree->a    = a;
-    qrtree->p    = p;
-    qrtree->args = NULL;
+        qrtree_init.a    = a;
+        qrtree_init.p    = p;
+        qrtree_init.args = NULL;
 
-    arg = (hqr_args_t*) malloc( sizeof(hqr_args_t) );
-    arg->domino = 0;
-    arg->tsrr = 0;
-    arg->perm = NULL;
+        arg = (hqr_args_t*) malloc( sizeof(hqr_args_t) );
+        arg->domino = 0;
+        arg->tsrr = 0;
+        arg->perm = NULL;
 
-    arg->llvl = (hqr_subpiv_t*) malloc( sizeof(hqr_subpiv_t) );
-    arg->hlvl = NULL;
+        arg->llvl = (hqr_subpiv_t*) malloc( sizeof(hqr_subpiv_t) );
+        arg->hlvl = NULL;
 
-    minMN = libhqr_imin(A->mt, A->nt);
-    low_mt = (qrtree->mt + p - 1) / ( p );
+        minMN = libhqr_imin(A->mt, A->nt);
+        low_mt = (qrtree_init.mt + p - 1) / ( p );
 
-    arg->llvl->minMN  = minMN;
-    arg->llvl->ldd    = low_mt;
-    arg->llvl->a      = a;
-    arg->llvl->p      = p;
-    arg->llvl->domino = 0;
+        arg->llvl->minMN  = minMN;
+        arg->llvl->ldd    = low_mt;
+        arg->llvl->a      = a;
+        arg->llvl->p      = p;
+        arg->llvl->domino = 0;
 
-    svd_low_adaptiv_init(arg->llvl, qrtree->mt, qrtree->nt,
-                         nbthread_per_node * (A->nodes / p), ratio );
+        svd_low_adaptiv_init(arg->llvl, qrtree_init.mt, qrtree_init.nt,
+                             nbthread_per_node * (A->nodes / p), ratio );
 
-    if ( p > 1 ) {
-        arg->hlvl = (hqr_subpiv_t*) malloc( sizeof(hqr_subpiv_t) );
+        if ( p > 1 ) {
+            arg->hlvl = (hqr_subpiv_t*) malloc( sizeof(hqr_subpiv_t) );
 
-        arg->llvl->minMN  = minMN;
-        arg->hlvl->ldd    = qrtree->mt;
-        arg->hlvl->a      = a;
-        arg->hlvl->p      = p;
-        arg->hlvl->domino = 0;
+            arg->llvl->minMN  = minMN;
+            arg->hlvl->ldd    = qrtree_init.mt;
+            arg->hlvl->a      = a;
+            arg->hlvl->p      = p;
+            arg->hlvl->domino = 0;
 
-        switch( type_hlvl ) {
-        case LIBHQR_FLAT_TREE :
-            hqr_high_flat_init(arg->hlvl);
-            break;
-        case LIBHQR_GREEDY_TREE :
-            hqr_high_greedy_init(arg->hlvl, minMN);
-            break;
-        case LIBHQR_GREEDY1P_TREE :
-            hqr_high_greedy1p_init(arg->hlvl);
-            break;
-        case LIBHQR_BINARY_TREE :
-            hqr_high_binary_init(arg->hlvl);
-            break;
-        case LIBHQR_FIBONACCI_TREE :
-            hqr_high_fibonacci_init(arg->hlvl);
-            break;
-        default:
-            hqr_high_fibonacci_init(arg->hlvl);
+            switch( type_hlvl ) {
+            case LIBHQR_FLAT_TREE :
+                hqr_high_flat_init(arg->hlvl);
+                break;
+            case LIBHQR_GREEDY_TREE :
+                hqr_high_greedy_init(arg->hlvl, minMN);
+                break;
+            case LIBHQR_GREEDY1P_TREE :
+                hqr_high_greedy1p_init(arg->hlvl);
+                break;
+            case LIBHQR_BINARY_TREE :
+                hqr_high_binary_init(arg->hlvl);
+                break;
+            case LIBHQR_FIBONACCI_TREE :
+                hqr_high_fibonacci_init(arg->hlvl);
+                break;
+            default:
+                hqr_high_fibonacci_init(arg->hlvl);
+            }
         }
+
+        qrtree_init.args = (void*)arg;
+        hqr_genperm( &qrtree_init );
     }
 
-    qrtree->args = (void*)arg;
+    /* Initialize the final QR tree */
+    memcpy( qrtree, &qrtree_init, sizeof(libhqr_tree_t) );
+    qrtree->getnbgeqrf = svd_getnbgeqrf;
+    qrtree->getm       = svd_getm;
+    qrtree->geti       = svd_geti;
+    qrtree->gettype    = rdmtx_gettype;
+    qrtree->currpiv    = rdmtx_currpiv;
+    qrtree->nextpiv    = rdmtx_nextpiv;
+    qrtree->prevpiv    = rdmtx_prevpiv;
+    qrtree->args       = malloc( qrtree->mt * qrtree->nt * sizeof(libhqr_tileinfo_t) );
+
+    /* Initialize the matrix */
+    libhqr_matrix_init(qrtree, &qrtree_init);
+
+    /* Free the initial qrtree */
+    libhqr_hqr_finalize( &qrtree_init );
 
     return 0;
 }
diff --git a/src/libhqr_systolic.c b/src/libhqr_systolic.c
index 66153a3..d66d3bb 100644
--- a/src/libhqr_systolic.c
+++ b/src/libhqr_systolic.c
@@ -19,6 +19,7 @@
 #include <assert.h>
 #include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
 
 #define PRINT_PIVGEN 0
 #ifdef PRINT_PIVGEN
@@ -352,6 +353,8 @@ libhqr_systolic_init( libhqr_tree_t *qrtree,
 		       libhqr_typefacto_e trans, libhqr_tiledesc_t *A,
 		       int p, int q )
 {
+    libhqr_tree_t qrtree_init;
+
     if (qrtree == NULL) {
 	fprintf(stderr, "libhqr_systolic_init, illegal value of qrtree");
 	return -1;
@@ -374,20 +377,39 @@ libhqr_systolic_init( libhqr_tree_t *qrtree,
 	return -5;
     }
 
+    /* Create a temporary qrtree structure based on the functions */
+    {
+        qrtree_init.getnbgeqrf = systolic_getnbgeqrf;
+        qrtree_init.getm       = systolic_getm;
+        qrtree_init.geti       = systolic_geti;
+        qrtree_init.gettype    = systolic_gettype;
+        qrtree_init.currpiv    = systolic_currpiv;
+        qrtree_init.nextpiv    = systolic_nextpiv;
+        qrtree_init.prevpiv    = systolic_prevpiv;
+
+        qrtree_init.mt   = (trans == LIBHQR_QR) ? A->mt : A->nt;
+        qrtree_init.nt   = (trans == LIBHQR_QR) ? A->nt : A->mt;
+
+        qrtree_init.a    = libhqr_imax( q, 1 );
+        qrtree_init.p    = libhqr_imax( p, 1 );
+        qrtree_init.args = NULL;
+    }
+
+    memcpy( qrtree, &qrtree_init, sizeof(libhqr_tree_t) );
     qrtree->getnbgeqrf = systolic_getnbgeqrf;
     qrtree->getm       = systolic_getm;
     qrtree->geti       = systolic_geti;
-    qrtree->gettype    = systolic_gettype;
-    qrtree->currpiv    = systolic_currpiv;
-    qrtree->nextpiv    = systolic_nextpiv;
-    qrtree->prevpiv    = systolic_prevpiv;
+    qrtree->gettype    = rdmtx_gettype;
+    qrtree->currpiv    = rdmtx_currpiv;
+    qrtree->nextpiv    = rdmtx_nextpiv;
+    qrtree->prevpiv    = rdmtx_prevpiv;
+    qrtree->args       = malloc( qrtree->mt * qrtree->nt * sizeof(libhqr_tileinfo_t) );
 
-    qrtree->mt   = (trans == LIBHQR_QR) ? A->mt : A->nt;
-    qrtree->nt   = (trans == LIBHQR_QR) ? A->nt : A->mt;
+    /*Initialize the matrix */
+    libhqr_matrix_init(qrtree, &qrtree_init);
 
-    qrtree->a    = libhqr_imax( q, 1 );
-    qrtree->p    = libhqr_imax( p, 1 );
-    qrtree->args = NULL;
+    /* Free the initial qrtree */
+    libhqr_systolic_finalize( &qrtree_init );
 
     return 0;
 }
diff --git a/testings/testing_pivgen.c b/testings/testing_pivgen.c
index ef8827e..3b80573 100644
--- a/testings/testing_pivgen.c
+++ b/testings/testing_pivgen.c
@@ -147,7 +147,7 @@ int main(int argc, char ** argv)
 	      ret |= 1;
 	    }
 
-	    libhqr_systolic_finalize( &qrtree );
+	    libhqr_matrix_finalize( &qrtree );
 
 	    done++;
 	    printf("\r%6d / %6d", done, todo);
-- 
GitLab