From 747c7935567f2c953a39b224cb116c9a924544eb Mon Sep 17 00:00:00 2001
From: Guillaume Sylvand <guillaume.sylvand@airbus.com>
Date: Tue, 20 Sep 2016 17:39:44 +0000
Subject: [PATCH] Add possibility to use z/cgemm3m for complex mat-mat products

This routine, available in MKL, does a product in 6n^3 ops instead of 8n^3
but is interesting only for "large enough" matrices (to be tested...)
Potentially, we gain 25 % in all complex computations.
It could be interesting to look for it / implement it in cuda.

!!! Note that the flop counters are not updated         !!!
!!! In C/Z accuracy, most flops counter should be x0.75 !!!

IT is OFF by default
It is activated with MORSE_Enable(MORSE_GEMM3M)
In the timing routines, it is activated with --gemm3m
---
 control/context.c             | 13 +++++++++++++
 coreblas/compute/core_zgemm.c | 12 ++++++++++++
 include/morse_constants.h     |  1 +
 include/morse_struct.h        |  1 +
 timing/timing.c               |  9 +++++++++
 timing/timing.h               |  1 +
 6 files changed, 37 insertions(+)

diff --git a/control/context.c b/control/context.c
index 38d892036..91b5434aa 100644
--- a/control/context.c
+++ b/control/context.c
@@ -83,6 +83,7 @@ MORSE_context_t *morse_context_create()
     morse->parallel_enabled     = MORSE_FALSE;
     morse->profiling_enabled    = MORSE_FALSE;
     morse->progress_enabled     = MORSE_FALSE;
+    morse->gemm3m_enabled       = MORSE_FALSE;
 
     morse->householder        = MORSE_FLAT_HOUSEHOLDER;
     morse->translation        = MORSE_OUTOFPLACE;
@@ -132,6 +133,7 @@ int morse_context_destroy(){
  *          @arg MORSE_AUTOTUNING autotuning for tile size and inner block size.
  *          @arg MORSE_PROFILING_MODE  activate profiling of kernels
  *          @arg MORSE_PROGRESS  activate progress indicator
+ *          @arg MORSE_GEMM3M  Use z/cgemm3m for complexe matrix-matrix products
  *
  *******************************************************************************
  *
@@ -166,6 +168,13 @@ int MORSE_Enable(MORSE_enum option)
         case MORSE_PROGRESS:
             morse->progress_enabled = MORSE_TRUE;
             break;
+        case MORSE_GEMM3M:
+#ifdef CBLAS_HAS_ZGEMM3M
+            morse->gemm3m_enabled = MORSE_TRUE;
+#else
+            morse_error("MORSE_Enable", "cannot enable GEMM3M (not available in cblas)");
+#endif
+            break;
         /* case MORSE_PARALLEL: */
         /*     morse->parallel_enabled = MORSE_TRUE; */
         /*     break; */
@@ -197,6 +206,7 @@ int MORSE_Enable(MORSE_enum option)
  *          @arg MORSE_AUTOTUNING autotuning for tile size and inner block size.
  *          @arg MORSE_PROFILING_MODE  deactivate profiling of kernels
  *          @arg MORSE_PROGRESS  deactivate progress indicator
+ *          @arg MORSE_GEMM3M  Use z/cgemm3m for complexe matrix-matrix products
  *
  *******************************************************************************
  *
@@ -230,6 +240,9 @@ int MORSE_Disable(MORSE_enum option)
         case MORSE_PROGRESS:
             morse->progress_enabled = MORSE_FALSE;
             break;
+        case MORSE_GEMM3M:
+            morse->gemm3m_enabled = MORSE_FALSE;
+            break;
         case MORSE_PARALLEL_MODE:
             morse->parallel_enabled = MORSE_FALSE;
             break;
diff --git a/coreblas/compute/core_zgemm.c b/coreblas/compute/core_zgemm.c
index fe98b80a9..b5b7e4d4b 100644
--- a/coreblas/compute/core_zgemm.c
+++ b/coreblas/compute/core_zgemm.c
@@ -42,6 +42,18 @@ void CORE_zgemm(MORSE_enum transA, int transB,
                                           const MORSE_Complex64_t *B, int LDB,
                 MORSE_Complex64_t beta, MORSE_Complex64_t *C, int LDC)
 {
+#ifdef CBLAS_HAS_ZGEMM3M
+  MORSE_context_t *morse = morse_context_self();
+  if (morse->gemm3m_enabled)
+    cblas_zgemm3m(
+        CblasColMajor,
+        (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
+        M, N, K,
+        CBLAS_SADDR(alpha), A, LDA,
+        B, LDB,
+        CBLAS_SADDR(beta), C, LDC);
+  else
+#endif
     cblas_zgemm(
         CblasColMajor,
         (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
diff --git a/include/morse_constants.h b/include/morse_constants.h
index 39e3c4605..8ffe90dc0 100644
--- a/include/morse_constants.h
+++ b/include/morse_constants.h
@@ -131,6 +131,7 @@
 #define MORSE_PARALLEL_MODE   6
 #define MORSE_BOUND           7
 #define MORSE_PROGRESS        8
+#define MORSE_GEMM3M          9
 
 /** ****************************************************************************
  *  MORSE constants - configuration parameters
diff --git a/include/morse_struct.h b/include/morse_struct.h
index 2958ce3d0..397d96e1e 100644
--- a/include/morse_struct.h
+++ b/include/morse_struct.h
@@ -132,6 +132,7 @@ typedef struct morse_context_s {
     MORSE_bool         parallel_enabled;
     MORSE_bool         profiling_enabled;
     MORSE_bool         progress_enabled;
+    MORSE_bool         gemm3m_enabled;
 
     MORSE_enum         householder;        // "domino" (flat) or tree-based (reduction) Householder
     MORSE_enum         translation;        // In place or Out of place layout conversion
diff --git a/timing/timing.c b/timing/timing.c
index 796441d6b..1c83dbd16 100644
--- a/timing/timing.c
+++ b/timing/timing.c
@@ -349,6 +349,7 @@ show_help(char *prog_name) {
             "  --[a]sync        Enable/Disable synchronous calls in wrapper function such as POTRI. (default: async)\n"
             "  --[no]check      Check result (default: nocheck)\n"
             "  --[no]progress   Display progress indicator (default: noprogress)\n"
+            "  --[no]gemm3m     Use gemm3m complex method (default: nogemm3m)\n"
             "  --[no]inv        Check on inverse (default: noinv)\n"
             "  --[no]warmup     Perform a warmup run to pre-load libraries (default: warmup)\n"
             "  --[no]trace      Enable/Disable trace generation (default: notrace)\n"
@@ -487,6 +488,7 @@ main(int argc, char *argv[]) {
     iparam[IPARAM_NMPI          ] = 1;
     iparam[IPARAM_P             ] = 1;
     iparam[IPARAM_Q             ] = 1;
+    iparam[IPARAM_GEMM3M        ] = 0;
     iparam[IPARAM_PROGRESS      ] = 0;
     iparam[IPARAM_PROFILE       ] = 0;
     iparam[IPARAM_PRINT_ERRORS  ] = 0;
@@ -526,6 +528,10 @@ main(int argc, char *argv[]) {
             iparam[IPARAM_TRACE] = 1;
         } else if (startswith( argv[i], "--notrace" )) {
             iparam[IPARAM_TRACE] = 0;
+        } else if (startswith( argv[i], "--gemm3m" )) {
+            iparam[IPARAM_GEMM3M] = 1;
+        } else if (startswith( argv[i], "--nogemm3m" )) {
+            iparam[IPARAM_GEMM3M] = 0;
         } else if (startswith( argv[i], "--progress" )) {
             iparam[IPARAM_PROGRESS] = 1;
         } else if (startswith( argv[i], "--noprogress" )) {
@@ -637,6 +643,9 @@ main(int argc, char *argv[]) {
     if (iparam[IPARAM_PROGRESS] == 1)
         MORSE_Enable(MORSE_PROGRESS);
 
+    if (iparam[IPARAM_GEMM3M] == 1)
+        MORSE_Enable(MORSE_GEMM3M);
+
 #if defined(CHAMELEON_USE_MPI)
     MORSE_Comm_size( &nbnode );
     iparam[IPARAM_NMPI] = nbnode;
diff --git a/timing/timing.h b/timing/timing.h
index 4f8771bb6..a1636491c 100644
--- a/timing/timing.h
+++ b/timing/timing.h
@@ -48,6 +48,7 @@ enum iparam_timing {
     IPARAM_Q,              /* Parameter for 2D cyclic distribution       */
 
     IPARAM_PROGRESS,       /* Use a progress indicator during computations */
+    IPARAM_GEMM3M,         /* Use GEMM3M for complex matrix vector products */
     /* Added for StarPU version */
     IPARAM_PROFILE,
     IPARAM_PRINT_ERRORS,
-- 
GitLab