MatrixOperations.hxx 21.8 KB
Newer Older
GILLES Sebastien's avatar
GILLES Sebastien committed
1 2 3 4 5 6 7 8 9 10 11 12 13
/*!
//
// \file
//
//
// Created by Sebastien Gilles <sebastien.gilles@inria.fr> on the Fri, 30 Oct 2015 12:41:42 +0100
// Copyright (c) Inria. All rights reserved.
//
// \ingroup ThirdPartyGroup
// \addtogroup ThirdPartyGroup
// \{
*/

14

15 16
#ifndef MOREFEM_x_THIRD_PARTY_x_WRAPPERS_x_PETSC_x_MATRIX_x_MATRIX_OPERATIONS_HXX_
# define MOREFEM_x_THIRD_PARTY_x_WRAPPERS_x_PETSC_x_MATRIX_x_MATRIX_OPERATIONS_HXX_
17 18


19
namespace MoReFEM
20
{
21 22


23 24
    namespace Wrappers
    {
25 26


27 28
        namespace Petsc
        {
29 30


31
            template<class MatrixT>
32
            std::enable_if_t<std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value, void>
33 34 35
            MatMultTranspose(const MatrixT& matrix,
                             const Vector& v1,
                             Vector& v2,
36 37
                             const char* invoking_file, int invoking_line,
                             update_ghost do_update_ghost)
38
            {
39 40 41
                int error_code = ::MatMultTranspose(matrix.InternalForReadOnly(),
                                                    v1.InternalForReadOnly(),
                                                    v2.Internal());
42 43
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMultTranspose", invoking_file, invoking_line);
44

45
                v2.UpdateGhosts(invoking_file, invoking_line, do_update_ghost);
46
            }
47 48


49
            template<class MatrixT>
50
            std::enable_if_t<std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value, void>
51 52 53 54
            MatMultTransposeAdd(const MatrixT& matrix,
                                const Vector& v1,
                                const Vector& v2,
                                Vector& v3,
55 56
                                const char* invoking_file, int invoking_line,
                                update_ghost do_update_ghost)
57
            {
58 59 60 61
                int error_code = ::MatMultTransposeAdd(matrix.InternalForReadOnly(),
                                                       v1.InternalForReadOnly(),
                                                       v2.InternalForReadOnly(),
                                                       v3.Internal());
62 63
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMultTransposeAdd", invoking_file, invoking_line);
64

65
                v3.UpdateGhosts(invoking_file, invoking_line, do_update_ghost);
66
            }
67 68


69 70
            template
            <
GILLES Sebastien's avatar
GILLES Sebastien committed
71 72 73
                class MatrixT,
                class MatrixU,
                class MatrixV
74 75 76
            >
            std::enable_if_t
            <
77
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
78 79
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixV>::value,
80 81
                void
            >
82
            MatMatMult(const MatrixT& matrix1,
83 84
                       const MatrixU& matrix2,
                       MatrixV& out,
85 86
                       const char* invoking_file, int invoking_line,
                       DoReuseMatrix do_reuse_matrix)
87 88
            {
                Mat result;
89
                int error_code { };
90

91
                switch (do_reuse_matrix)
92 93 94 95
                {
                    case DoReuseMatrix::yes:
                    {
                        result = out.Internal();
96 97
                        error_code = ::MatMatMult(matrix1.InternalForReadOnly(),
                                                  matrix2.InternalForReadOnly(),
98 99 100 101 102 103 104
                                                  MAT_REUSE_MATRIX,
                                                  PETSC_DEFAULT,
                                                  &result);
                        break;
                    }
                    case DoReuseMatrix::no:
                    {
105 106
                        error_code = ::MatMatMult(matrix1.InternalForReadOnly(),
                                                  matrix2.InternalForReadOnly(),
107 108 109 110 111 112
                                                  MAT_INITIAL_MATRIX,
                                                  PETSC_DEFAULT,
                                                  &result);
                        out.SetFromPetscMat(result);
                        break;
                    }
113 114
                    case DoReuseMatrix::in_place:
                    {
115
                        static_cast<void>(error_code);
116
                        assert(false && "In place matrix option not supported for this function.");
117
                        exit(EXIT_FAILURE);
118 119
                        break;
                    }
120
                } // switch
121

122 123 124
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMatMult", invoking_file, invoking_line);
            }
125 126


127 128
            template
            <
GILLES Sebastien's avatar
GILLES Sebastien committed
129 130 131 132
                class MatrixT,
                class MatrixU,
                class MatrixV,
                class MatrixW
133 134 135
            >
            std::enable_if_t
            <
136
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
137 138 139
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixV>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixW>::value,
140 141
                void
            >
142
            MatMatMatMult(const MatrixT& matrix1,
143 144 145
                          const MatrixU& matrix2,
                          const MatrixV& matrix3,
                          MatrixW& out,
146 147
                          const char* invoking_file, int invoking_line,
                          DoReuseMatrix do_reuse_matrix)
148 149
            {
                Mat result;
150
                int error_code {};
151

152
                switch (do_reuse_matrix)
153 154 155 156
                {
                    case DoReuseMatrix::yes:
                    {
                        result = out.Internal();
157 158 159
                        error_code = ::MatMatMatMult(matrix1.InternalForReadOnly(),
                                                     matrix2.InternalForReadOnly(),
                                                     matrix3.InternalForReadOnly(),
160 161 162
                                                     MAT_REUSE_MATRIX,
                                                     PETSC_DEFAULT,
                                                     &result);
163 164


165 166 167 168
                        break;
                    }
                    case DoReuseMatrix::no:
                    {
169 170 171
                        error_code = ::MatMatMatMult(matrix1.InternalForReadOnly(),
                                                     matrix2.InternalForReadOnly(),
                                                     matrix3.InternalForReadOnly(),
172 173 174 175 176 177 178
                                                     MAT_INITIAL_MATRIX,
                                                     PETSC_DEFAULT,
                                                     &result);

                        out.SetFromPetscMat(result);
                        break;
                    }
179 180
                    case DoReuseMatrix::in_place:
                    {
181
                        static_cast<void>(error_code);
182
                        assert(false && "In place matrix option not supported for this function.");
183
                        exit(EXIT_FAILURE);
184 185
                        break;
                    }
186
                } // switch
187

188 189 190
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMatMatMult", invoking_file, invoking_line);
            }
191 192


193 194
            template
            <
195
                NonZeroPattern NonZeroPatternT,
196 197 198 199 200
                class MatrixT,
                class MatrixU
            >
            std::enable_if_t
            <
201 202
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value,
203 204
                void
            >
205
            AXPY(PetscScalar a,
206 207
                 const MatrixT& X,
                 MatrixU& Y,
208
                 const char* invoking_file, int invoking_line)
209
            {
210 211
                int error_code = ::MatAXPY(Y.Internal(),
                                           a,
212
                                           X.InternalForReadOnly(),
213
                                           Internal::Wrappers::Petsc::NonZeroPatternPetsc<NonZeroPatternT>());
214

215 216 217
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatAXPY", invoking_file, invoking_line);
            }
218 219


220 221 222 223
            template<class MatrixT>
            std::enable_if_t<std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value, void>
            MatShift(const PetscScalar a,
                     MatrixT& matrix,
224
                     const char* invoking_file, int invoking_line)
225 226 227 228 229
            {
                int error_code = ::MatShift(matrix.Internal(), a);
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatShift", invoking_file, invoking_line);
            }
230 231


232
            template<class MatrixT>
233
            std::enable_if_t<std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value, void>
234 235 236
            MatMult(const MatrixT& matrix,
                    const Vector& v1,
                    Vector& v2,
237 238
                    const char* invoking_file, int invoking_line,
                    update_ghost do_update_ghost)
239
            {
240
                int error_code = ::MatMult(matrix.InternalForReadOnly(), v1.InternalForReadOnly(), v2.Internal());
241 242
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMult", invoking_file, invoking_line);
243

244
                v2.UpdateGhosts(invoking_file, invoking_line, do_update_ghost);
245
            }
246 247


248 249
            template
            <
GILLES Sebastien's avatar
GILLES Sebastien committed
250 251
                class MatrixT,
                class MatrixU
252 253 254
            >
            std::enable_if_t
            <
GILLES Sebastien's avatar
GILLES Sebastien committed
255 256 257
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value,
                void
258 259 260
            >
            MatTranspose(MatrixT& matrix1,
                         MatrixU& matrix2,
261 262
                         const char* invoking_file, int invoking_line,
                         DoReuseMatrix do_reuse_matrix)
263
            {
264 265
                static_assert(std::is_const<MatrixU>() == false);

266
                Mat result;
267
                int error_code {};
268 269 270 271
                switch (do_reuse_matrix)
                {
                    case DoReuseMatrix::no:
                    {
272
                        error_code = ::MatTranspose(matrix1.InternalForReadOnly(), MAT_INITIAL_MATRIX, &result);
273 274 275 276 277 278
                        matrix2.SetFromPetscMat(result);
                        break;
                    }
                    case DoReuseMatrix::yes:
                    {
                        result = matrix2.Internal();
279
                        error_code = ::MatTranspose(matrix1.InternalForReadOnly(), MAT_REUSE_MATRIX, &result);
280 281 282 283
                        break;
                    }
                    case DoReuseMatrix::in_place:
                    {
284 285 286 287 288 289 290 291 292 293 294
                        result = matrix2.Internal();
                        assert(matrix1.InternalForReadOnly() == matrix2.Internal() && "For in place transpose both arguments"
                                                                                      "are expected to be pointers to the"
                                                                                      "same PETSc matrix object.");
                        // < note: Internal() and InternalForReadOnly() actually provide the same to the underlying
                        // pointer, but matrix1 might be const in some calls for the other values of do_reuse_matrix
                        // so Internal() couldn't be called on it reliably - to do so I would have to introduce
                        // a template class to enable partial specialization (I NEVER want to have to specify explicitly
                        // the matrix template parameters...).

                        error_code = ::MatTranspose(matrix2.Internal(), MAT_INPLACE_MATRIX, &result);
295 296 297 298
                        break;
                    }
                } // switch
                
299 300 301
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatTranspose", invoking_file, invoking_line);
            }
302 303


304
            template<class MatrixT>
305
            std::enable_if_t<std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value, void>
306 307 308 309
            MatMultAdd(const MatrixT& matrix,
                       const Vector& v1,
                       const Vector& v2,
                       Vector& v3,
310 311
                       const char* invoking_file, int invoking_line,
                       update_ghost do_update_ghost)
312
            {
313 314 315 316
                int error_code = ::MatMultAdd(matrix.InternalForReadOnly(),
                                              v1.InternalForReadOnly(),
                                              v2.InternalForReadOnly(),
                                              v3.Internal());
317 318
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMultAdd", invoking_file, invoking_line);
319

320
                v3.UpdateGhosts(invoking_file, invoking_line, do_update_ghost);
321
            }
322 323 324



325 326
            template
            <
327 328
                class MatrixT,
                class MatrixU
329
            >
330 331
            std::enable_if_t
            <
332 333
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value,
334 335 336
                void
            >
            MatCreateTranspose(const MatrixT& A,
337
                               MatrixU& transpose,
338 339 340
                               const char* invoking_file, int invoking_line)
            {
                Mat result;
341

342
                int error_code = ::MatCreateTranspose(A.InternalForReadOnly(), &result);
343 344
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatCreateTranspose", invoking_file, invoking_line);
345

346 347
                transpose.SetFromPetscMat(result);
            }
348 349


350 351
            template
            <
352 353 354
                class MatrixT,
                class MatrixU,
                class MatrixV
355 356 357
            >
            std::enable_if_t
            <
358
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
359 360
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixV>::value,
361 362 363
                void
            >
            MatTransposeMatMult(const MatrixT& matrix1,
364 365
                                const MatrixU& matrix2,
                                MatrixV& matrix3,
366 367
                                const char* invoking_file, int invoking_line,
                                DoReuseMatrix do_reuse_matrix)
368 369
            {
                Mat result;
370

371
                int error_code {};
372

373
                switch (do_reuse_matrix)
374 375 376 377
                {
                    case DoReuseMatrix::yes:
                    {
                        result = matrix3.Internal();
378

379 380
                        error_code = ::MatTransposeMatMult(matrix1.InternalForReadOnly(),
                                                           matrix2.InternalForReadOnly(),
381 382 383
                                                           MAT_REUSE_MATRIX,
                                                           PETSC_DEFAULT,
                                                           &result);
384

385 386 387 388
                        break;
                    }
                    case DoReuseMatrix::no:
                    {
389 390
                        error_code = ::MatTransposeMatMult(matrix1.InternalForReadOnly(),
                                                           matrix2.InternalForReadOnly(),
391 392 393
                                                           MAT_INITIAL_MATRIX,
                                                           PETSC_DEFAULT,
                                                           &result);
394

395
                        matrix3.SetFromPetscMat(result);
396

397 398
                        break;
                    }
399 400
                    case DoReuseMatrix::in_place:
                    {
401
                        static_cast<void>(error_code);
402
                        assert(false && "In place matrix option not supported for this function.");
403
                        exit(EXIT_FAILURE);
404 405
                        break;
                    }
406
                } // switch
407

408 409 410
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatTransposeMatMult", invoking_file, invoking_line);
            }
411 412


413 414
            template
            <
415 416 417
                class MatrixT,
                class MatrixU,
                class MatrixV
418 419 420
            >
            std::enable_if_t
            <
421
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
422 423
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixV>::value,
424 425 426
                void
            >
            MatMatTransposeMult(const MatrixT& matrix1,
427 428
                                const MatrixU& matrix2,
                                MatrixV& matrix3,
429 430
                                const char* invoking_file, int invoking_line,
                                DoReuseMatrix do_reuse_matrix)
431 432
            {
                Mat result;
433
                int error_code {};
434

435
                switch (do_reuse_matrix)
436 437 438 439
                {
                    case DoReuseMatrix::yes:
                    {
                        result = matrix3.Internal();
440 441
                        error_code = ::MatMatTransposeMult(matrix1.InternalForReadOnly(),
                                                           matrix2.InternalForReadOnly(),
442 443 444 445 446 447 448
                                                           MAT_REUSE_MATRIX,
                                                           PETSC_DEFAULT,
                                                           &result);
                        break;
                    }
                    case DoReuseMatrix::no:
                    {
449 450
                        error_code = ::MatMatTransposeMult(matrix1.InternalForReadOnly(),
                                                           matrix2.InternalForReadOnly(),
451 452 453
                                                           MAT_INITIAL_MATRIX,
                                                           PETSC_DEFAULT,
                                                           &result);
454

455 456 457
                        matrix3.SetFromPetscMat(result);
                        break;
                    }
458 459
                    case DoReuseMatrix::in_place:
                    {
460
                        static_cast<void>(error_code);
461
                        assert(false && "In place matrix option not supported for this function.");
462
                        exit(EXIT_FAILURE);
463 464
                        break;
                    }
465
                } // switch
466

467 468 469
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatMatTransposeMult", invoking_file, invoking_line);
            }
470 471 472



473 474
            template
            <
475 476 477
                class MatrixT,
                class MatrixU,
                class MatrixV
478 479 480
            >
            std::enable_if_t
            <
481
                std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixT>::value
482 483
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixU>::value
                && std::is_base_of<Internal::Wrappers::Petsc::BaseMatrix, MatrixV>::value,
484 485 486
                void
            >
            PtAP(const MatrixT& A,
487 488
                 const MatrixU& P,
                 MatrixV& out,
489 490
                 const char* invoking_file, int invoking_line,
                 DoReuseMatrix do_reuse_matrix)
491 492
            {
                Mat result;
493
                int error_code {};
494

495
                switch (do_reuse_matrix)
496 497 498 499
                {
                    case DoReuseMatrix::yes:
                    {
                        result = out.Internal();
500 501
                        error_code = ::MatPtAP(A.InternalForReadOnly(),
                                               P.InternalForReadOnly(),
502 503 504 505 506 507 508
                                               MAT_REUSE_MATRIX,
                                               PETSC_DEFAULT,
                                               &result);
                        break;
                    }
                    case DoReuseMatrix::no:
                    {
509 510
                        error_code = ::MatPtAP(A.InternalForReadOnly(),
                                               P.InternalForReadOnly(),
511 512 513 514 515 516
                                               MAT_INITIAL_MATRIX,
                                               PETSC_DEFAULT,
                                               &result);
                        out.SetFromPetscMat(result);
                        break;
                    }
517 518
                    case DoReuseMatrix::in_place:
                    {
519
                        static_cast<void>(error_code);
520
                        assert(false && "In place matrix option not supported for this function.");
521
                        exit(EXIT_FAILURE);
522 523
                        break;
                    }
524 525
                } // switch

526 527 528
                if (error_code)
                    throw ExceptionNS::Exception(error_code, "MatPtAP", invoking_file, invoking_line);
            }
529 530


531
        } // namespace Petsc
532 533


534
    } // namespace Wrappers
535 536


537
} // namespace MoReFEM
538 539


540 541 542
/// @} // addtogroup ThirdPartyGroup


543
#endif // MOREFEM_x_THIRD_PARTY_x_WRAPPERS_x_PETSC_x_MATRIX_x_MATRIX_OPERATIONS_HXX_