diff --git a/tools/gen_wrappers.py b/tools/gen_wrappers.py index 87bbe3bb17654388839d0d4bcba6d0d17fd51263..d33f0280559d124d7a162e72f95bd3f39360aa34 100755 --- a/tools/gen_wrappers.py +++ b/tools/gen_wrappers.py @@ -479,32 +479,38 @@ spm_enums = { 'filename' : [ "include/spm/const.h" ], 'python' : { 'filename' : "wrappers/python/spm/enum.py.in", 'description' : "SPM python wrapper to define enums and datatypes", - 'header' : "# Start with __ to prevent broadcast to file importing enum\n" - "__spm_int__ = @SPM_PYTHON_INTEGER@\n" - "__spm_mpi_enabled__ = @SPM_PYTHON_MPI_ENABLED@\n", + 'header' : """ +# Start with __ to prevent broadcast to file importing enum +__spm_int__ = @SPM_PYTHON_INTEGER@ +__spm_mpi_enabled__ = @SPM_PYTHON_MPI_ENABLED@ +""", 'footer' : "", 'enums' : { 'coeftype' : enums_python_coeftype, 'mtxtype' : " SymPosDef = trans.ConjTrans + 1\n HerPosDef = trans.ConjTrans + 2\n" } }, 'fortran' : { 'filename' : "wrappers/fortran90/src/spm_enums.F90", 'description' : "SPM fortran 90 wrapper to define enums and datatypes", - 'header' : "#if defined(SPM_WITH_MPI)\n" - " use mpi_f08\n" - "#endif\n" - " implicit none\n" - "\n" - "#if !defined(SPM_WITH_MPI)\n" - " type, bind(c) :: MPI_Comm\n" - " integer(kind=c_int) :: MPI_Comm\n" - " end type MPI_Comm\n" - "#endif\n", + 'header' : """ +#if defined(SPM_WITH_MPI) + use mpi_f08 +#endif + implicit none + +#if !defined(SPM_WITH_MPI) + type, bind(c) :: MPI_Comm + integer(kind=c_int) :: MPI_Comm + end type MPI_Comm +#endif +""", 'footer' : enums_fortran_footer, 'enums' : { 'mtxtype' : " enumerator :: SpmSymPosDef = SpmConjTrans + 1\n enumerator :: SpmHerPosDef = SpmConjTrans + 2\n" } }, 'julia' : { 'filename' : "wrappers/julia/spm/src/spm_enums.jl.in", 'description' : "SPM julia wrapper to define enums and datatypes", - 'header' :"const spm_int_t = @SPM_JULIA_INTEGER@\n" - "const spm_mpi_enabled = @SPM_JULIA_MPI_ENABLED@\n", + 'header' :""" +const spm_int_t = @SPM_JULIA_INTEGER@ +const spm_mpi_enabled = @SPM_JULIA_MPI_ENABLED@ +""", 'footer' : "", 'enums' : {} }, @@ -514,57 +520,75 @@ spm = { 'filename' : [ "include/spm.h" ], 'python' : { 'filename' : "wrappers/python/spm/__spm__.py", 'description' : "SPM python wrapper", - 'header' : "from . import libspm\n" - "from .enum import __spm_int__\n" - "from .enum import __spm_mpi_enabled__\n" - "\n" - "if __spm_mpi_enabled__:\n" - " from mpi4py import MPI\n" - "\n" - "def __get_mpi_type__():\n" - " if not __spm_mpi_enabled__:\n" - " return c_int\n" - " if MPI._sizeof(MPI.Comm) == sizeof(c_long):\n" - " return c_long\n" - " elif MPI._sizeof(MPI.Comm) == sizeof(c_int):\n" - " return c_int\n" - " else:\n" - " return c_void_p\n", + 'header' : """ +from . import libspm +from .enum import __spm_int__ +from .enum import __spm_mpi_enabled__ + +if __spm_mpi_enabled__: + from mpi4py import MPI + if MPI._sizeof(MPI.Comm) == sizeof(c_long): + pyspm_mpi_comm = c_long + elif MPI._sizeof(MPI.Comm) == sizeof(c_int): + pyspm_mpi_comm = c_int + else: + pyspm_mpi_comm = c_void_p + + pyspm_default_comm = MPI.COMM_WORLD + + def pyspm_convert_comm( comm ): + comm_ptr = MPI._addressof(comm) + return pyspm_mpi_comm.from_address(comm_ptr) +else: + pyspm_mpi_comm = c_int + + pyspm_default_comm = 0 + + def pyspm_convert_comm( comm ): + return c_int(comm) +""", 'footer' : "", 'enums' : {} }, 'fortran' : { 'filename' : "wrappers/fortran90/src/spmf.f90", 'description' : "SPM Fortran 90 wrapper", - 'header' : " use spm_enums\n" - " implicit none\n", + 'header' : """ + use spm_enums + implicit none +""", 'footer' : "", 'enums' : {} }, 'julia' : { 'filename' : "wrappers/julia/spm/src/spm.jl", 'description' : "SPM julia wrapper", - 'header' : "module spm\n" - "using CBinding\n" - "using Libdl\n" - "include(\"spm_enums.jl\")\n\n" - "function spm_library_path()\n" - " x = Libdl.dlext\n" - " return \"libspm.$x\"\n" - "end\n" - "libspm = spm_library_path()\n\n" - "if spm_mpi_enabled\n" - " using MPI\n" - "end\n\n" - "function __get_mpi_type__()\n" - " if !spm_mpi_enabled\n" - " return Cint\n" - " elseif sizeof(MPI.MPI_Comm) == sizeof(Clong)\n" - " return Clong\n" - " elseif sizeof(MPI.MPI_Comm) == sizeof(Cint)\n" - " return Cint\n" - " end\n" - " return Cvoid\n" - "end\n", - 'footer' : "end #module", + 'header' : """ +module spm +using CBinding +using Libdl +include(\"spm_enums.jl\") + +function spm_library_path() + x = Libdl.dlext + return \"libspm.$x\" +end +libspm = spm_library_path() + +if spm_mpi_enabled + using MPI +end + +function __get_mpi_type__() + if !spm_mpi_enabled + return Cint + elseif sizeof(MPI.MPI_Comm) == sizeof(Clong) + return Clong + elseif sizeof(MPI.MPI_Comm) == sizeof(Cint) + return Cint + end + return Cvoid +end +""", + 'footer' : "end #module", 'enums' : {} }, } diff --git a/tools/wrappers/wrap_fortran.py b/tools/wrappers/wrap_fortran.py index 735def330f230074268cbfa545ffa9bf27d5be95..d2e721a3c82d76be6c396f72a72de4a6eee07501 100644 --- a/tools/wrappers/wrap_fortran.py +++ b/tools/wrappers/wrap_fortran.py @@ -171,8 +171,7 @@ class wrap_fortran: !> @ingroup wrap_fortran !> module ''' + modname + ''' - use iso_c_binding -''' + use iso_c_binding''' if f['header'] != "": header += f['header'] diff --git a/tools/wrappers/wrap_julia.py b/tools/wrappers/wrap_julia.py index d47dd14b1b769441c0410fa8d33b02dd65e10baf..b1209ddb37b242f60f387502d874d404c72368bf 100644 --- a/tools/wrappers/wrap_julia.py +++ b/tools/wrappers/wrap_julia.py @@ -120,7 +120,7 @@ class wrap_julia: =# ''' if f['header'] != "": - header += "\n" + f['header'] + header += f['header'] return header; @staticmethod diff --git a/tools/wrappers/wrap_python.py b/tools/wrappers/wrap_python.py index bbb4f64e81687b1532c75e703850f2135c562c87..8c7f7c6aa388bad2e11e4ae7292a46bf33fe7d32 100644 --- a/tools/wrappers/wrap_python.py +++ b/tools/wrappers/wrap_python.py @@ -50,7 +50,7 @@ types_dict = { "double": ("c_double"), "float": ("c_float"), "void": ("c_void"), - "MPI_Comm": ("__get_mpi_type__()"), + "MPI_Comm": ("pyspm_mpi_comm"), "FILE": ("c_void"), } @@ -121,6 +121,10 @@ def iso_c_wrapper_type(arg, args_list, args_size): if arg[1] == "**": f_call = "pointer( " + f_call + " )" + # Call to communicators + if (arg[0] == "PASTIX_Comm") or (arg[0] == "MPI_Comm"): + f_call = "pyspm_convert_comm( " + f_call + " )" + args_size[0] = max(args_size[0], len(f_name)) args_size[1] = max(args_size[1], len(f_type)) args_size[2] = max(args_size[2], len(f_call)) @@ -156,7 +160,7 @@ from ctypes import * import numpy as np ''' if f['header'] != "": - header += "\n" + f['header'] + header += f['header'] return header; @staticmethod diff --git a/wrappers/julia/spm/src/spm.jl b/wrappers/julia/spm/src/spm.jl index 25bf2afeef35c7c8b041a2808a605e9a936bbd87..3ea2ed96d0eeceb35be6de2ec2434db2a99af8d3 100644 --- a/wrappers/julia/spm/src/spm.jl +++ b/wrappers/julia/spm/src/spm.jl @@ -11,7 +11,7 @@ @author Mathieu Faverge @author Selmane Lebdaoui @author Tony Delarue - @date 2021-04-04 + @date 2021-04-07 This file has been automatically generated with gen_wrappers.py @@ -35,14 +35,14 @@ if spm_mpi_enabled end function __get_mpi_type__() - if !spm_mpi_enabled - return Cint - elseif sizeof(MPI.MPI_Comm) == sizeof(Clong) - return Clong - elseif sizeof(MPI.MPI_Comm) == sizeof(Cint) - return Cint - end - return Cvoid + if !spm_mpi_enabled + return Cint + elseif sizeof(MPI.MPI_Comm) == sizeof(Clong) + return Clong + elseif sizeof(MPI.MPI_Comm) == sizeof(Cint) + return Cint + end + return Cvoid end @cstruct spmatrix_t { @@ -227,4 +227,4 @@ end @cextern spmDofExtend( spm::Ptr{spmatrix_t}, type::Cint, dof::Cint )::Ptr{spmatrix_t} end -end #module +end #module diff --git a/wrappers/python/spm/__spm__.py b/wrappers/python/spm/__spm__.py index 416667482f228649e4273177d4fd700a4aa1233b..20a6fc52d1bbf6df5f468ce474e3b40516cfd9a6 100644 --- a/wrappers/python/spm/__spm__.py +++ b/wrappers/python/spm/__spm__.py @@ -11,7 +11,7 @@ @author Pierre Ramet @author Mathieu Faverge @author Tony Delarue - @date 2021-04-04 + @date 2021-04-07 This file has been automatically generated with gen_wrappers.py @@ -27,16 +27,25 @@ from .enum import __spm_mpi_enabled__ if __spm_mpi_enabled__: from mpi4py import MPI - -def __get_mpi_type__(): - if not __spm_mpi_enabled__: - return c_int if MPI._sizeof(MPI.Comm) == sizeof(c_long): - return c_long + pyspm_mpi_comm = c_long elif MPI._sizeof(MPI.Comm) == sizeof(c_int): - return c_int + pyspm_mpi_comm = c_int else: - return c_void_p + pyspm_mpi_comm = c_void_p + + pyspm_default_comm = MPI.COMM_WORLD + + def pyspm_convert_comm( comm ): + comm_ptr = MPI._addressof(comm) + return pyspm_mpi_comm.from_address(comm_ptr) +else: + pyspm_mpi_comm = c_int + + pyspm_default_comm = 0 + + def pyspm_convert_comm( comm ): + return c_int(comm) class pyspm_spmatrix_t(Structure): _fields_ = [("mtxtype", c_int ), @@ -61,7 +70,7 @@ class pyspm_spmatrix_t(Structure): ("glob2loc", POINTER(__spm_int__)), ("clustnum", c_int ), ("clustnbr", c_int ), - ("comm", __get_mpi_type__() ) ] + ("comm", pyspm_mpi_comm ) ] def pyspm_spmInit( spm ): libspm.spmInit.argtypes = [ POINTER(pyspm_spmatrix_t) ] @@ -103,18 +112,17 @@ def pyspm_spmGenFakeValues( spm ): libspm.spmGenFakeValues( spm ) def pyspm_spmInitDist( spm, comm ): - libspm.spmInitDist.argtypes = [ POINTER(pyspm_spmatrix_t), - __get_mpi_type__() ] - libspm.spmInitDist( spm, comm ) + libspm.spmInitDist.argtypes = [ POINTER(pyspm_spmatrix_t), pyspm_mpi_comm ] + libspm.spmInitDist( spm, pyspm_convert_comm( comm ) ) def pyspm_spmScatter( spm, n, loc2glob, distByColumn, root, comm ): libspm.spmScatter.argtypes = [ POINTER(pyspm_spmatrix_t), __spm_int__, POINTER(__spm_int__), c_int, c_int, - __get_mpi_type__() ] + pyspm_mpi_comm ] libspm.spmScatter.restype = POINTER(pyspm_spmatrix_t) return libspm.spmScatter( spm, n, loc2glob.ctypes.data_as( POINTER(__spm_int__) ), - distByColumn, root, comm ) + distByColumn, root, pyspm_convert_comm( comm ) ) def pyspm_spmGather( spm, root ): libspm.spmGather.argtypes = [ POINTER(pyspm_spmatrix_t), c_int ] diff --git a/wrappers/python/spm/spm.py b/wrappers/python/spm/spm.py index d67fcd8bce6ae3d7a1a6b1beb394b90ecd5b0e73..e111ac9dbadbef84ba0e2f69e52b311b05a69436 100644 --- a/wrappers/python/spm/spm.py +++ b/wrappers/python/spm/spm.py @@ -35,7 +35,7 @@ class spmatrix(): dtype = None - def __init__( self, A=None, mtxtype_=mtxtype.General, driver=None, filename="" ): + def __init__( self, A=None, mtxtype_=mtxtype.General, driver=None, filename="", comm=pyspm_default_comm ): """ Initialize the SPM wrapper by loading the libraries """ @@ -50,9 +50,10 @@ class spmatrix(): 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, None, layout.ColMajor, - None, None, None, None ) + None, None, None, None, None, + 0, 1, pyspm_convert_comm( comm ) ) self.id_ptr = pointer( self.spm_c ) - self.init() + self.init( comm ) if A is not None: self.fromsps( A, mtxtype_ ) @@ -157,8 +158,8 @@ class spmatrix(): def printSpm( self ): pyspm_spmPrint( self.id_ptr ) - def init( self ): - pyspm_spmInit( self.id_ptr ) + def init( self, comm=pyspm_default_comm ): + pyspm_spmInitDist( self.id_ptr, comm ) def checkAndCorrect( self ): spm1 = self.id_ptr