// ===================================================================================
// Copyright ScalFmm 2011 INRIA, Olivier Coulaud, Bérenger Bramas, Matthias Messner
// olivier.coulaud@inria.fr, berenger.bramas@inria.fr
// This software is a computer program whose purpose is to compute the FMM.
//
// This software is governed by the CeCILL-C and LGPL licenses and
// abiding by the rules of distribution of free software.  
// 
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public and CeCILL-C Licenses for more details.
// "http://www.cecill.info". 
// "http://www.gnu.org/licenses".
// ===================================================================================
#ifndef FQUICKSORTMPI_HPP
#define FQUICKSORTMPI_HPP

#include <cstddef>
#include <vector>

#include "FQuickSort.hpp"
#include "FMpi.hpp"


/**
 * Distributed quick sort on top of FQuickSort, using MPI to move data.
 *
 * Template parameters:
 *  - SortType    : type of the elements being sorted; elements are moved
 *                  between processes as raw bytes (MPI_BYTE), so the type must
 *                  be safe to copy byte-wise.
 *  - CompareType : type of the pivot; built from an element with
 *                  CompareType(element).
 *  - IndexType   : integer type used for element counts and offsets.
 */
template <class SortType, class CompareType, class IndexType>
class FQuickSortMpi : public FQuickSort< SortType, CompareType, IndexType> {
    /**
     * Element counts of one process after a partition step:
     * `pre` elements are stored first and go to the "under pivot" group,
     * `suf` elements follow them and go to the "above pivot" group.
     * Declared at class scope (and not inside QsMpi) so that it can be used
     * as a std::vector element type with any C++ standard.
     */
    struct Fix{
        IndexType pre;
        IndexType suf;
    };

public:
    /**
     * The MPI quick sort.
     *
     * Each round: a pivot is obtained collectively, the local data are split
     * around it by QsLocal, the "under pivot" elements are redistributed over
     * the processes [0 ; splitProc] and the "above pivot" elements over the
     * processes ]splitProc ; nbProcs - 1], then the communicator is reduced to
     * the group the calling process belongs to. Rounds stop when the process is
     * alone in its communicator; the remaining local data are then sorted with
     * QsOmp.
     *
     * This is a collective call: it uses collective operations on the
     * communicator, so every process of originalComm has to call it.
     *
     * @param originalArray the local input elements; they are copied, never modified
     * @param size          the number of elements in originalArray
     * @param outputArray   [out] set to an array allocated here with new[] holding
     *                      the local part of the result; it is not released by this
     *                      function
     * @param outputSize    [out] the number of valid elements in outputArray on
     *                      return (the allocation itself may be larger)
     * @param originalComm  the communicator gathering the processes that sort together
     */
    static void QsMpi(const SortType originalArray[], IndexType size, SortType* & outputArray, IndexType& outputSize, const FMpi::FComm& originalComm){
        FTRACE( FTrace::FFunction functionTrace(__FUNCTION__, "Quicksort" , __FILE__ , __LINE__) );

        // QsLocal and QsOmp live in a base class that depends on the template
        // parameters: they must be called through a qualified name, an
        // unqualified call is not looked up in a dependent base.
        typedef FQuickSort< SortType, CompareType, IndexType> ParentClass;

        // first we copy data into our working buffer : outputArray
        outputArray = new SortType[size];
        FMemUtils::memcpy(outputArray, originalArray, sizeof(SortType) * size);
        outputSize = size;

        // Storage for the pre/suffix counts and their running sums, maximum
        // needed is nb procs + 1. The vectors value-initialize every counter,
        // so all entries start at zero.
        // The sub-communicators used below are assumed to never contain more
        // processes than originalComm.
        const int maxNbProcs = originalComm.processCount();
        std::vector<Fix> fixes(static_cast<std::size_t>(maxNbProcs) + 1);
        std::vector<Fix> fixesSum(static_cast<std::size_t>(maxNbProcs) + 1);

        // At most one send and one receive per process and per round
        std::vector<MPI_Request> requests(static_cast<std::size_t>(maxNbProcs) * 2);

        // receiving buffer, it only grows and is reused from round to round
        IndexType bufferSize = 0;
        SortType* buffer = 0;

        // Create the first com
        FMpi::FComm currentComm(originalComm.getComm());

        // While I am not working alone on my own data
        while( currentComm.processCount() != 1 ){
            const int currentRank = currentComm.processId();
            const int currentNbProcs = currentComm.processCount();

            // Number of requests posted during this round
            int iterRequest = 0;

            /////////////////////////////////////////////////
            // Local sort
            /////////////////////////////////////////////////

            // The pivot is computed collectively from one local element per process.
            // NOTE(review): when this process holds no element (size == 0) the
            // read of outputArray[size/2] is outside of the allocated array;
            // presumably every process is expected to own at least one element
            // - to be confirmed.
            const CompareType pivot = currentComm.reduceAverageAll( CompareType(outputArray[size/2]) );
            // Split the local data around the pivot; QsLocal is expected to fill
            // myFix.pre/myFix.suf with the size of the lower/upper part.
            Fix myFix;
            ParentClass::QsLocal(outputArray, pivot, 0, size - 1, myFix.pre, myFix.suf);

            // exchange fixes, sent as raw bytes
            FMpi::Assert( MPI_Allgather( &myFix, sizeof(Fix), MPI_BYTE, &fixes[0], sizeof(Fix), MPI_BYTE, currentComm.getComm()),  __LINE__ );

            // each procs compute the summation :
            // fixesSum[p] = number of elements owned by the procs [0 ; p[
            fixesSum[0].pre = 0;
            fixesSum[0].suf = 0;
            for(int idxProc = 0 ; idxProc < currentNbProcs ; ++idxProc){
                fixesSum[idxProc + 1].pre = fixesSum[idxProc].pre + fixes[idxProc].pre;
                fixesSum[idxProc + 1].suf = fixesSum[idxProc].suf + fixes[idxProc].suf;
            }

            // then I need to know which procs will be in the middle
            int splitProc = FMpi::GetProc(fixesSum[currentNbProcs].pre - 1, fixesSum[currentNbProcs].pre + fixesSum[currentNbProcs].suf, currentNbProcs);
            // keep at least one process in the "above pivot" group
            if(splitProc == currentNbProcs - 1){
                --splitProc;
            }

            /////////////////////////////////////////////////
            // Send my data
            /////////////////////////////////////////////////

            // above pivot (right part)
            if( fixes[currentRank].suf ){
                const int procsInSuf = currentNbProcs - 1 - splitProc;
                const int firstProcInSuf = splitProc + 1;
                const IndexType elementsInSuf = fixesSum[currentNbProcs].suf;

                // My suffix covers [fixesSum[rank].suf ; fixesSum[rank+1].suf[
                // in the global suffix numbering
                const int firstProcToSend = FMpi::GetProc(fixesSum[currentRank].suf, elementsInSuf, procsInSuf) + firstProcInSuf;
                const int lastProcToSend = FMpi::GetProc(fixesSum[currentRank + 1].suf - 1, elementsInSuf, procsInSuf) + firstProcInSuf;

                IndexType sent = 0;
                for(int idxProc = firstProcToSend ; idxProc <= lastProcToSend ; ++idxProc){
                    const IndexType thisProcRight = FMpi::GetRight(elementsInSuf, idxProc - firstProcInSuf, procsInSuf);
                    IndexType sendToProc = thisProcRight - fixesSum[currentRank].suf - sent;

                    // never send more than what is left in my suffix
                    if(sendToProc + sent > fixes[currentRank].suf){
                        sendToProc = fixes[currentRank].suf - sent;
                    }
                    if( sendToProc ){
                        // the suffix is stored after the prefix in outputArray
                        FMpi::Assert( MPI_Isend(&outputArray[sent + fixes[currentRank].pre], int(sendToProc * sizeof(SortType)), MPI_BYTE , idxProc, FMpi::TagQuickSort, currentComm.getComm(), &requests[iterRequest++]),  __LINE__ );
                    }
                    sent += sendToProc;
                }
            }

            // under pivot (left part)
            if( fixes[currentRank].pre ){
                const int procsInPre = splitProc + 1;
                const IndexType elementsInPre = fixesSum[currentNbProcs].pre;

                // My prefix covers [fixesSum[rank].pre ; fixesSum[rank+1].pre[
                // in the global prefix numbering
                const int firstProcToSend = FMpi::GetProc(fixesSum[currentRank].pre, elementsInPre, procsInPre);
                const int lastProcToSend = FMpi::GetProc(fixesSum[currentRank + 1].pre - 1, elementsInPre, procsInPre);

                IndexType sent = 0;
                for(int idxProc = firstProcToSend ; idxProc <= lastProcToSend ; ++idxProc){
                    const IndexType thisProcRight = FMpi::GetRight(elementsInPre, idxProc, procsInPre);
                    IndexType sendToProc = thisProcRight - fixesSum[currentRank].pre - sent;

                    // never send more than what is left in my prefix
                    if(sendToProc + sent > fixes[currentRank].pre){
                        sendToProc = fixes[currentRank].pre - sent;
                    }
                    if(sendToProc){
                        FMpi::Assert( MPI_Isend(&outputArray[sent], int(sendToProc * sizeof(SortType)), MPI_BYTE , idxProc, FMpi::TagQuickSort, currentComm.getComm(), &requests[iterRequest++]),  __LINE__ );
                    }
                    sent += sendToProc;
                }
            }

            /////////////////////////////////////////////////
            // Receive data that belong to me
            /////////////////////////////////////////////////

            if( currentRank <= splitProc ){
                // I am in S-Part (smaller than pivot)
                const int procsInPre = splitProc + 1;
                const IndexType elementsInPre = fixesSum[currentNbProcs].pre;

                // The interval of the global prefix I will own : [myLeft ; myRightLimit[
                IndexType myLeft = FMpi::GetLeft(elementsInPre, currentRank, procsInPre);
                IndexType myRightLimit = FMpi::GetRight(elementsInPre, currentRank, procsInPre);

                size = myRightLimit - myLeft;
                if(bufferSize < size){
                    bufferSize = size;
                    delete[] buffer;
                    buffer = new SortType[bufferSize];
                }

                // skip the procs whose prefix ends before my interval
                int idxProc = 0;
                while( idxProc < currentNbProcs && fixesSum[idxProc + 1].pre <= myLeft ){
                    ++idxProc;
                }

                IndexType indexArray = 0;

                // receive from each proc the part of its prefix that overlaps my interval
                while( idxProc < currentNbProcs && indexArray < myRightLimit - myLeft){
                    const IndexType firstIndex = FMath::Max(myLeft , fixesSum[idxProc].pre );
                    const IndexType endIndex = FMath::Min(fixesSum[idxProc + 1].pre,  myRightLimit);
                    if( (endIndex - firstIndex) ){
                        FMpi::Assert( MPI_Irecv(&buffer[indexArray], int((endIndex - firstIndex) * sizeof(SortType)), MPI_BYTE, idxProc, FMpi::TagQuickSort, currentComm.getComm(), &requests[iterRequest++]),  __LINE__ );
                    }
                    indexArray += endIndex - firstIndex;
                    ++idxProc;
                }
                // Proceed all send/receive
                FMpi::Assert( MPI_Waitall(iterRequest, &requests[0], MPI_STATUSES_IGNORE),  __LINE__ );

                // continue with the procs of the S-Part only
                currentComm.groupReduce( 0, splitProc);
            }
            else{
                // I am in L-Part (larger than pivot)
                const int procsInSuf = currentNbProcs - 1 - splitProc;
                const IndexType elementsInSuf = fixesSum[currentNbProcs].suf;

                // The interval of the global suffix I will own : [myLeft ; myRightLimit[
                const int rankInL = currentRank - splitProc - 1;
                IndexType myLeft = FMpi::GetLeft(elementsInSuf, rankInL, procsInSuf);
                IndexType myRightLimit = FMpi::GetRight(elementsInSuf, rankInL, procsInSuf);

                size = myRightLimit - myLeft;
                if(bufferSize < size){
                    bufferSize = size;
                    delete[] buffer;
                    buffer = new SortType[bufferSize];
                }

                // skip the procs whose suffix ends before my interval
                int idxProc = 0;
                while( idxProc < currentNbProcs && fixesSum[idxProc + 1].suf <= myLeft ){
                    ++idxProc;
                }

                IndexType indexArray = 0;

                // receive from each proc the part of its suffix that overlaps my interval
                while( idxProc < currentNbProcs && indexArray < myRightLimit - myLeft){
                    const IndexType firstIndex = FMath::Max(myLeft , fixesSum[idxProc].suf );
                    const IndexType endIndex = FMath::Min(fixesSum[idxProc + 1].suf,  myRightLimit);
                    if( (endIndex - firstIndex) ){
                        FMpi::Assert( MPI_Irecv(&buffer[indexArray], int((endIndex - firstIndex) * sizeof(SortType)), MPI_BYTE, idxProc, FMpi::TagQuickSort, currentComm.getComm(), &requests[iterRequest++]),  __LINE__ );
                    }
                    indexArray += endIndex - firstIndex;
                    ++idxProc;
                }
                // Proceed all send/receive
                FMpi::Assert( MPI_Waitall(iterRequest, &requests[0], MPI_STATUSES_IGNORE),  __LINE__ );

                // continue with the procs of the L-Part only
                currentComm.groupReduce( splitProc + 1, currentNbProcs - 1);
            }



            // Copy res into outputArray, which is enlarged only when it is too
            // small. All the sends reading outputArray are completed here
            // (MPI_Waitall above).
            if(outputSize < size){
                delete[] outputArray;
                outputArray = new SortType[size];
                outputSize = size;
            }

            FMemUtils::memcpy(outputArray, buffer, sizeof(SortType) * size);
        }

        /////////////////////////////////////////////////
        // End QsMpi sort
        /////////////////////////////////////////////////

        // Clean
        delete[] buffer;

        // Normal Quick sort
        ParentClass::QsOmp(outputArray, size);
        outputSize = size;
    }

};

#endif // FQUICKSORTMPI_HPP