From f7ad22802abf16ee8f5b9902bb11150e79a1e7b2 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Wed, 31 Aug 2022 12:00:13 +0200
Subject: [PATCH] pzhemm: Added the first version of pzhemm_Astat

---
 compute/pzhemm.c | 251 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 251 insertions(+)

diff --git a/compute/pzhemm.c b/compute/pzhemm.c
index 75f1ab66b..12b61a707 100644
--- a/compute/pzhemm.c
+++ b/compute/pzhemm.c
@@ -30,6 +30,257 @@
 #define WA( _m_, _n_ ) WA, (_m_), (_n_)
 #define WB( _m_, _n_ ) WB, (_m_), (_n_)
 
+/**
+ *  Parallel tile matrix-matrix multiplication.
+ *  Generic algorithm for any data distribution with a stationnary A.
+ *
+ * Assuming A has been setup with a proper getrank function to account for symmetry
+ */
+static inline void
+chameleon_pzhemm_Astat( CHAM_context_t *chamctxt, cham_side_t side, cham_uplo_t uplo,
+                        CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
+                        CHAMELEON_Complex64_t beta,  CHAM_desc_t *C,
+                        RUNTIME_option_t *options )
+{
+    const CHAMELEON_Complex64_t zone = (CHAMELEON_Complex64_t)1.0;
+    RUNTIME_sequence_t *sequence = options->sequence;
+    int                 k, m, n, l, Am, An;
+    int                 tempmm, tempnn, tempkn, tempkm;
+    int                 myrank = RUNTIME_comm_rank( chamctxt );
+    int                 reduceC[ C->mt * C->nt ];
+
+    /* Set C tiles to redux mode */
+    for (n = 0; n < C->nt; n++) {
+        for (m = 0; m < C->mt; m++) {
+            reduceC[ n * C->mt + m ] = 0;
+
+            /* The node owns the C tile. */
+            if ( C->get_rankof( C(m, n) ) == myrank ) {
+                reduceC[ n * C->mt + m ] = 1;
+                RUNTIME_zgersum_set_methods( C(m, n) );
+                continue;
+            }
+
+            /*
+             * The node owns the A tile that will define the locality of the
+             * computations.
+             */
+            /* Select row or column based on side */
+            l = ( side == ChamLeft ) ? m : n;
+
+            if ( uplo == ChamLower ) {
+                for (k = 0; k < A->mt; k++) {
+                    Am = k;
+                    An = k;
+
+                    if (k < l) {
+                        Am = l;
+                    }
+                    else if (k > l) {
+                        An = l;
+                    }
+
+                    if ( A->get_rankof( A( Am, An ) ) == myrank ) {
+                        reduceC[ n * C->mt + m ] = 1;
+                        RUNTIME_zgersum_set_methods( C(m, n) );
+                        break;
+                    }
+                }
+            }
+            else {
+                for (k = 0; k < A->mt; k++) {
+                    Am = k;
+                    An = k;
+
+                    if (k < l) {
+                        An = l;
+                    }
+                    else if (k > l) {
+                        Am = l;
+                    }
+
+                    if ( A->get_rankof( A( Am, An ) ) == myrank ) {
+                        reduceC[ n * C->mt + m ] = 1;
+                        RUNTIME_zgersum_set_methods( C(m, n) );
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    for(n = 0; n < C->nt; n++) {
+        tempnn = n == C->nt-1 ? C->n-n*C->nb : C->nb;
+        for(m = 0; m < C->mt; m++) {
+            tempmm = m == C->mt-1 ? C->m-m*C->mb : C->mb;
+
+            /* Scale C */
+            options->forcesub = 0;
+            INSERT_TASK_zlascal( options, ChamUpperLower, tempmm, tempnn, C->mb,
+                                 beta, C, m, n );
+            options->forcesub = reduceC[ n * C->mt + m ];
+
+            /*
+             *  ChamLeft / ChamLower
+             */
+            /* Select row or column based on side */
+            l = ( side == ChamLeft ) ? m : n;
+
+            if (side == ChamLeft) {
+                if (uplo == ChamLower) {
+                    for (k = 0; k < C->mt; k++) {
+                        tempkm = k == C->mt-1 ? C->m-k*C->mb : C->mb;
+
+                        if (k < m) {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkm, A->mb,
+                                alpha, A(m, k),  /* lda * K */
+                                       B(k, n),  /* ldb * Y */
+                                zone,  C(m, n)); /* ldc * Y */
+                        }
+                        else if (k == m) {
+                                INSERT_TASK_zhemm_Astat(
+                                    options,
+                                    side, uplo,
+                                    tempmm, tempnn, A->mb,
+                                    alpha, A(k, k),  /* ldak * X */
+                                           B(k, n),  /* ldb  * Y */
+                                    zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkm, A->mb,
+                                alpha, A(k, m),  /* ldak * X */
+                                       B(k, n),  /* ldb  * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                    }
+                }
+                /*
+                 *  ChamLeft / ChamUpper
+                 */
+                else {
+                    for (k = 0; k < C->mt; k++) {
+                        tempkm = k == C->mt-1 ? C->m-k*C->mb : C->mb;
+
+                        if (k < m) {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkm, A->mb,
+                                alpha, A(k, m),  /* ldak * X */
+                                       B(k, n),  /* ldb  * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else if (k == m) {
+                            INSERT_TASK_zhemm_Astat(
+                                options,
+                                side, uplo,
+                                tempmm, tempnn, A->mb,
+                                alpha, A(k, k),  /* ldak * K */
+                                       B(k, n),  /* ldb  * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkm, A->mb,
+                                alpha, A(m, k),  /* lda * K */
+                                       B(k, n),  /* ldb * Y */
+                                zone,  C(m, n)); /* ldc * Y */
+                        }
+                    }
+                }
+            }
+            /*
+             *  ChamRight / ChamLower
+             */
+            else {
+                if (uplo == ChamLower) {
+                    for (k = 0; k < C->nt; k++) {
+                        tempkn = k == C->nt-1 ? C->n-k*C->nb : C->nb;
+
+                        if (k < n) {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamTrans,
+                                tempmm, tempnn, tempkn, A->mb,
+                                alpha, B(m, k),  /* ldb * K */
+                                       A(n, k),  /* lda * K */
+                                zone,  C(m, n)); /* ldc * Y */
+                        }
+                        else if (k == n) {
+                            INSERT_TASK_zhemm_Astat(
+                                options,
+                                side, uplo,
+                                tempmm, tempnn, A->mb,
+                                alpha, A(k, k),  /* ldak * Y */
+                                       B(m, k),  /* ldb  * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkn, A->mb,
+                                alpha, B(m, k),  /* ldb  * K */
+                                       A(k, n),  /* ldak * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                    }
+                }
+                /*
+                 *  ChamRight / ChamUpper
+                 */
+                else {
+                    for (k = 0; k < C->nt; k++) {
+                        tempkn = k == C->nt-1 ? C->n-k*C->nb : C->nb;
+
+                        if (k < n) {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamNoTrans,
+                                tempmm, tempnn, tempkn, A->mb,
+                                alpha, B(m, k),  /* ldb  * K */
+                                       A(k, n),  /* ldak * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else if (k == n) {
+                            INSERT_TASK_zhemm_Astat(
+                                options,
+                                side, uplo,
+                                tempmm, tempnn, A->mb,
+                                alpha, A(k, k),  /* ldak * Y */
+                                       B(m, k),  /* ldb  * Y */
+                                zone,  C(m, n)); /* ldc  * Y */
+                        }
+                        else {
+                            INSERT_TASK_zgemm_Astat(
+                                options,
+                                ChamNoTrans, ChamTrans,
+                                tempmm, tempnn, tempkn, A->mb,
+                                alpha, B(m, k),  /* ldb * K */
+                                       A(n, k),  /* lda * K */
+                                zone,  C(m, n)); /* ldc * Y */
+                        }
+                    }
+                }
+            }
+
+            RUNTIME_zgersum_submit_tree( options, C(m, n) );
+            RUNTIME_data_flush( sequence, C(m, n) );
+        }
+    }
+    options->forcesub = 0;
+    (void)chamctxt;
+}
+
+
 /**
  *  Parallel tile hermitian matrix-matrix multiplication.
  *  SUMMA algorithm for 2D block-cyclic data distribution.
-- 
GitLab