// ===================================================================================
// Copyright ScalFmm 2011 INRIA, Olivier Coulaud, Bérenger Bramas, Matthias Messner
// olivier.coulaud@inria.fr, berenger.bramas@inria.fr
// This software is a computer program whose purpose is to compute the FMM.
//
// This software is governed by the CeCILL-C and LGPL licenses and
// abiding by the rules of distribution of free software.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public and CeCILL-C Licenses for more details.
// "http://www.cecill.info".
// "http://www.gnu.org/licenses".
// ===================================================================================
#ifndef FQUICKSORTMPI_HPP
#define FQUICKSORTMPI_HPP

#include "FQuickSort.hpp"
#include "FMpi.hpp"
#include "FLog.hpp"

#include <cstring>
#include <memory>
#include <utility>
#include <vector>

template <class SortType, class CompareType, class IndexType>
class FQuickSortMpi : public FQuickSort< SortType, CompareType, IndexType> {

    /** Number of elements a process holds on each side of the pivot after
      * its local partition. Every process gathers all of them so that all
      * processes can compute the same exchange plan (see Distribute).
      */
    struct Partition{
        IndexType lowerPart;   // number of elements with CompareType(e) <= pivot
        IndexType greaterPart; // number of elements with CompareType(e) >  pivot
    };

    /** One point-to-point message of the exchange plan, seen from the current
      * process: exchange the elements [fromElement, toElement) with idProc.
      * On the sender side the range indexes the outgoing part; on the receiver
      * side it indexes the receive buffer.
      */
    struct PackData {
        int idProc;
        IndexType fromElement;
        IndexType toElement;
    };


    /** Exchange two values using move operations. */
    static void Swap(SortType& value, SortType& other){
        // temp must not be const: std::move on a const object silently copies
        SortType temp = std::move(value);
        value = std::move(other);
        other = std::move(temp);
    }

    /** A local iteration of qs.
      * Reorders array[left..right] (inclusive) in place so that every element with
      * CompareType(e) <= pivot comes before every element greater than pivot.
      * @return the index of the first element greater than pivot; with left == 0
      *         this is the number of elements lower than or equal to the pivot
      */
    static IndexType QsPartition(SortType array[], IndexType left, IndexType right, const CompareType& pivot){
        // Skip the prefix that is already in place
        IndexType idx = left;
        while( idx <= right && CompareType(array[idx]) <= pivot){
            idx += 1;
        }
        left = idx;

        // Move every remaining lower-or-equal element to the front
        for( ; idx <= right ; ++idx){
            if( CompareType(array[idx]) <= pivot ){
                Swap(array[idx],array[left]);
                left += 1;
            }
        }

        return left;
    }

    /** Compute the exchange plan for one direction and keep only the messages
      * that involve currentRank.
      * If inFromRightToLeft is true, processes [procInTheMiddle+1, currentNbProcs)
      * send their lower parts to processes [0, procInTheMiddle]. Otherwise processes
      * [0, procInTheMiddle] send their greater parts to [procInTheMiddle+1, currentNbProcs).
      * Every process evaluates the same plan from the same gathered balance, so the
      * sends and the receives match without extra communication.
      * @param globalElementBalance per-process Partition (currentNbProcs entries)
      * @param globalElementBalanceSum exclusive prefix sums of globalElementBalance
      *        (currentNbProcs + 1 entries, entry 0 is zero)
      * @return the messages this process has to send (if it is a sender) or to
      *         receive (if it is a receiver) in this direction
      */
    static std::vector<PackData> Distribute(const int currentRank, const int currentNbProcs ,
                                            const Partition globalElementBalance[], const Partition globalElementBalanceSum[],
                                            const int procInTheMiddle, const bool inFromRightToLeft){
        // First agree on who send and who recv
        const int firstProcToSend = (inFromRightToLeft ? procInTheMiddle+1 : 0);
        const int lastProcToSend  = (inFromRightToLeft ? currentNbProcs : procInTheMiddle+1);
        const int firstProcToRecv = (inFromRightToLeft ? 0 : procInTheMiddle+1);
        const int lastProcToRecv  = (inFromRightToLeft ? procInTheMiddle+1 : currentNbProcs);
        // Get the number of element depending on the lower or greater send/recv
        const IndexType totalElementToProceed = (inFromRightToLeft ?
                                                     globalElementBalanceSum[lastProcToSend].lowerPart - globalElementBalanceSum[firstProcToSend].lowerPart :
                                                     globalElementBalanceSum[lastProcToSend].greaterPart - globalElementBalanceSum[firstProcToSend].greaterPart);
        const IndexType totalElementAlreadyOwned = (inFromRightToLeft ?
                                                        globalElementBalanceSum[lastProcToRecv].lowerPart - globalElementBalanceSum[firstProcToRecv].lowerPart :
                                                        globalElementBalanceSum[lastProcToRecv].greaterPart - globalElementBalanceSum[firstProcToRecv].greaterPart);

        const int nbProcsToRecv = (lastProcToRecv-firstProcToRecv);
        const int nbProcsToSend = (lastProcToSend-firstProcToSend);
        const IndexType totalElements = (totalElementToProceed+totalElementAlreadyOwned);

        // Get the number of elements each proc should recv: each receiver is
        // topped up to the average of what remains for the remaining receivers
        std::vector<IndexType> nbElementsToRecvPerProc(nbProcsToRecv);
        {
            IndexType totalRemainingElements = totalElements;

            for(int idxProc = firstProcToRecv; idxProc < lastProcToRecv ; ++idxProc){
                const IndexType nbElementsAlreadyOwned = (inFromRightToLeft ? globalElementBalance[idxProc].lowerPart : globalElementBalance[idxProc].greaterPart);
                const IndexType averageNbElementForRemainingProc = (totalRemainingElements)/(lastProcToRecv-idxProc);
                totalRemainingElements -= nbElementsAlreadyOwned;
                if(nbElementsAlreadyOwned < averageNbElementForRemainingProc){
                    nbElementsToRecvPerProc[idxProc - firstProcToRecv] = (averageNbElementForRemainingProc - nbElementsAlreadyOwned);
                    totalRemainingElements -= nbElementsToRecvPerProc[idxProc - firstProcToRecv];
                }
                else{
                    nbElementsToRecvPerProc[idxProc - firstProcToRecv] = 0;
                }
                ////FLOG( FLog::Controller << currentRank << "] nbElementsToRecvPerProc[" << idxProc << "] = " << nbElementsToRecvPerProc[idxProc - firstProcToRecv] << "\n"; )
            }
        }

        // Store in an array the number of element to send (each sender sends its whole part)
        std::vector<IndexType> nbElementsToSendPerProc(nbProcsToSend);
        for(int idxProc = firstProcToSend; idxProc < lastProcToSend ; ++idxProc){
            const IndexType nbElementsAlreadyOwned = (inFromRightToLeft ? globalElementBalance[idxProc].lowerPart : globalElementBalance[idxProc].greaterPart);
            nbElementsToSendPerProc[idxProc-firstProcToSend] = nbElementsAlreadyOwned;
            ////FLOG( FLog::Controller << currentRank << "] nbElementsToSendPerProc[" << idxProc << "] = " << nbElementsToSendPerProc[idxProc-firstProcToSend] << "\n"; )
        }

        // Compute all the send recv but keep only the ones related to currentRank.
        // Senders and receivers are walked in rank order, pairing what remains on
        // the current sender with what the current receiver still needs.
        std::vector<PackData> packs;
        int idxProcSend = 0;
        IndexType positionElementsSend = 0;
        int idxProcRecv = 0;
        IndexType positionElementsRecv = 0;
        while(idxProcSend != nbProcsToSend && idxProcRecv != nbProcsToRecv){
            if(nbElementsToSendPerProc[idxProcSend] == 0){
                // Current sender is exhausted; offsets are relative to each sender
                idxProcSend += 1;
                positionElementsSend = 0;
            }
            else if(nbElementsToRecvPerProc[idxProcRecv] == 0){
                // Current receiver is full; offsets are relative to each receiver
                idxProcRecv += 1;
                positionElementsRecv = 0;
            }
            else {
                const IndexType nbElementsInPack = FMath::Min(nbElementsToSendPerProc[idxProcSend], nbElementsToRecvPerProc[idxProcRecv]);
                if(idxProcSend + firstProcToSend == currentRank){
                    PackData pack;
                    pack.idProc      = idxProcRecv + firstProcToRecv;
                    pack.fromElement = positionElementsSend;
                    pack.toElement   = pack.fromElement + nbElementsInPack;
                    packs.push_back(pack);
                }
                else if(idxProcRecv + firstProcToRecv == currentRank){
                    PackData pack;
                    pack.idProc      = idxProcSend + firstProcToSend;
                    pack.fromElement = positionElementsRecv;
                    pack.toElement   = pack.fromElement + nbElementsInPack;
                    packs.push_back(pack);
                }
                nbElementsToSendPerProc[idxProcSend] -= nbElementsInPack;
                nbElementsToRecvPerProc[idxProcRecv] -= nbElementsInPack;
                positionElementsSend += nbElementsInPack;
                positionElementsRecv += nbElementsInPack;
            }
        }

        return packs;
    }

    /** Receive this process's share of the exchange described by Distribute.
      * @param inPartRecv receives a buffer allocated with new[]; the caller must delete[] it
      * @param inNbElementsRecv receives the number of elements in *inPartRecv
      * NOTE(review): message sizes are passed to MPI as an int byte count, so a
      * single pack larger than INT_MAX bytes would overflow — confirm sizes stay below that.
      */
    static void RecvDistribution(SortType ** inPartRecv, IndexType* inNbElementsRecv,
                                 const Partition globalElementBalance[], const Partition globalElementBalanceSum[],
                                 const int procInTheMiddle, const FMpi::FComm& currentComm, const bool inFromRightToLeft){
        // Compute to know what to recv
        const std::vector<PackData> whatToRecvFromWho = Distribute(currentComm.processId(), currentComm.processCount(),
                                                                   globalElementBalance, globalElementBalanceSum,
                                                                   procInTheMiddle, inFromRightToLeft);
        // Count the total number of elements to recv
        IndexType totalToRecv = 0;
        for(const PackData& pack : whatToRecvFromWho){
            totalToRecv += pack.toElement - pack.fromElement;
        }
        // Alloc buffer
        SortType* recvBuffer = new SortType[totalToRecv];

        // Recv all data (std::vector instead of a non-standard variable-length array)
        std::vector<MPI_Request> requests(whatToRecvFromWho.size());
        for(int idxPack = 0 ; idxPack < int(whatToRecvFromWho.size()) ; ++idxPack){
            const PackData& pack = whatToRecvFromWho[idxPack];
            ////FLOG( FLog::Controller << currentComm.processId() << "] Recv from " << pack.idProc << " from " << pack.fromElement << " to " << pack.toElement << "\n"; )
            FMpi::Assert( MPI_Irecv(&recvBuffer[pack.fromElement], int((pack.toElement - pack.fromElement) * sizeof(SortType)), MPI_BYTE, pack.idProc,
                          FMpi::TagQuickSort, currentComm.getComm(), &requests[idxPack]) , __LINE__);
        }
        // Wait to complete
        FMpi::Assert( MPI_Waitall(int(requests.size()), requests.data(), MPI_STATUSES_IGNORE),  __LINE__ );
        ////FLOG( FLog::Controller << currentComm.processId() << "] Recv Done \n"; )

        // Copy to ouput variables
        (*inPartRecv) = recvBuffer;
        (*inNbElementsRecv) = totalToRecv;
    }

    /** Send this process's part according to the exchange described by Distribute.
      * inPartToSend must point to the part being sent (lower or greater elements).
      * inNbElementsToSend is not read; the ranges come from Distribute.
      */
    static void SendDistribution(const SortType * inPartToSend, const IndexType inNbElementsToSend,
                                 const Partition globalElementBalance[], const Partition globalElementBalanceSum[],
                                 const int procInTheMiddle, const FMpi::FComm& currentComm, const bool inFromRightToLeft){
        // Compute to know what to send
        const std::vector<PackData> whatToSendToWho = Distribute(currentComm.processId(), currentComm.processCount(),
                                                                 globalElementBalance, globalElementBalanceSum,
                                                                 procInTheMiddle, inFromRightToLeft);

        // Post send messages (std::vector instead of a non-standard variable-length array)
        std::vector<MPI_Request> requests(whatToSendToWho.size());
        for(int idxPack = 0 ; idxPack < int(whatToSendToWho.size()) ; ++idxPack){
            const PackData& pack = whatToSendToWho[idxPack];
            ////FLOG( FLog::Controller << currentComm.processId() << "] Send to " << pack.idProc << " from " << pack.fromElement << " to " << pack.toElement << "\n"; )
            // const_cast only for pre-MPI-3 bindings that take a non-const send buffer
            FMpi::Assert( MPI_Isend(const_cast<SortType*>(&inPartToSend[pack.fromElement]), int((pack.toElement - pack.fromElement) * sizeof(SortType)), MPI_BYTE , pack.idProc,
                          FMpi::TagQuickSort, currentComm.getComm(), &requests[idxPack]) , __LINE__);
        }
        // Wait to complete
        FMpi::Assert( MPI_Waitall(int(requests.size()), requests.data(), MPI_STATUSES_IGNORE),  __LINE__ );
        ////FLOG( FLog::Controller << currentComm.processId() << "] Send Done \n"; )
    }

    /** Agree on a global pivot among all processes of currentComm (collective call).
      * Each non-empty process proposes the average of the elements at positions
      * size/3 and 2*size/3, lowered by one if it equals its local maximum (and not
      * all its values are equal). The global pivot is the average of the proposals
      * of non-empty processes.
      * @param shouldStop set to true when every non-empty process holds only one
      *        and the same value (nothing left to split); the return value is then 0
      */
    static CompareType SelectPivot(const SortType workingArray[], const IndexType currentSize, const FMpi::FComm& currentComm, bool* shouldStop){
        enum ValuesState{
            ALL_THE_SAME,
            NO_VALUES,
            AVERAGE_2
        };

        // Check if empty: in that case workingArray must not be read at all
        const bool noValues = (currentSize == 0);

        bool allTheSame = true;
        CompareType localPivot = CompareType(0);
        if(!noValues){
            // We need to know the max value to ensure that the pivot will be different.
            // Scan the whole array: stopping at the first differing value would only
            // give the max of a prefix.
            CompareType maxFoundValue = CompareType(workingArray[0]);
            for(IndexType idx = 1 ; idx < currentSize ; ++idx){
                // Check if all the same
                if(allTheSame && workingArray[0] != workingArray[idx]){
                    allTheSame = false;
                }
                // Keep the max
                maxFoundValue = FMath::Max(maxFoundValue , CompareType(workingArray[idx]));
            }

            // Get the local pivot
            localPivot = (CompareType(workingArray[currentSize/3])+CompareType(workingArray[(2*currentSize)/3]))/2;
            // The pivot must be different (to ensure that the partition will return two parts)
            if( localPivot == maxFoundValue && !allTheSame){
                ////FLOG( FLog::Controller << currentComm.processId() << "] Pivot " << localPivot << " is equal max and allTheSame equal " << allTheSame << "\n"; )
                localPivot -= 1;
            }
        }

        ////FLOG( FLog::Controller << currentComm.processId() << "] localPivot = " << localPivot << "\n" );

        const int nbProcs = currentComm.processCount();
        // Exchange the pivots and the state
        std::unique_ptr<int[]> allProcStates(new int[nbProcs]);
        std::unique_ptr<CompareType[]> allProcPivots(new CompareType[nbProcs]);
        {
            int myState = (noValues?NO_VALUES:(allTheSame?ALL_THE_SAME:AVERAGE_2));
            FMpi::Assert( MPI_Allgather( &myState, 1, MPI_INT, allProcStates.get(),
                                         1, MPI_INT, currentComm.getComm()),  __LINE__ );
            FMpi::Assert( MPI_Allgather( &localPivot, sizeof(CompareType), MPI_BYTE, allProcPivots.get(),
                                         sizeof(CompareType), MPI_BYTE, currentComm.getComm()),  __LINE__ );
        }
        // Test if all the procs have ALL_THE_SAME and the same value
        bool allProcsAreSame = true;
        for(int idxProc = 0 ; idxProc < nbProcs && allProcsAreSame; ++idxProc){
            if(allProcStates[idxProc] != NO_VALUES && (allProcStates[idxProc] != ALL_THE_SAME || allProcPivots[0] != allProcPivots[idxProc])){
                allProcsAreSame = false;
            }
        }

        if(allProcsAreSame){
            // All the procs are the same so we need to stop
            (*shouldStop) = true;
            return CompareType(0);
        }
        else{
            CompareType globalPivot = 0;
            CompareType counterValuesInPivot = 0;
            // Compute the pivot from the non-empty procs (at least one exists here)
            for(int idxProc = 0 ; idxProc < nbProcs; ++idxProc){
                if(allProcStates[idxProc] != NO_VALUES){
                    globalPivot += allProcPivots[idxProc];
                    counterValuesInPivot += 1;
                }
            }
            (*shouldStop) = false;
            return globalPivot/counterValuesInPivot;
        }
    }

public:

    /** Distributed quick sort (collective over originalComm).
      * The input is copied, never modified. At each step the processes agree on a
      * pivot, split their data, move lower-or-equal elements to the lower ranks and
      * greater elements to the higher ranks, and then split the communicator in two.
      * When a process is alone (or the remaining values are all identical) it sorts
      * its data locally. As a result, the elements of rank i are lower than or equal
      * to those of rank i+1.
      * SortType is copied with memcpy and sent as MPI_BYTE, so it must be trivially copyable.
      * @param outputArray receives an array allocated with new[]; the caller must delete[] it
      * @param outputSize receives the number of elements in *outputArray
      */
    static void QsMpi(const SortType originalArray[], const IndexType originalSize,
                      SortType** outputArray, IndexType* outputSize, const FMpi::FComm& originalComm){
        // We do not work in place, so create a new array
        IndexType currentSize = originalSize;
        SortType* workingArray = new SortType[currentSize];
        FMemUtils::memcpy(workingArray, originalArray, sizeof(SortType) * currentSize);

        // Clone the MPI group because we will reduce it after each partition
        FMpi::FComm currentComm(originalComm.getComm());

        // Parallel sharing until I am alone on the data
        while( currentComm.processCount() != 1 ){
            // Agree on the Pivot
            bool shouldStop;
            const CompareType globalPivot = SelectPivot(workingArray, currentSize, currentComm, &shouldStop);
            if(shouldStop){
                ////FLOG( FLog::Controller << currentComm.processId() << "] shouldStop = " << shouldStop << "\n" );
                break;
            }

            ////FLOG( FLog::Controller << currentComm.processId() << "] globalPivot = " << globalPivot << "\n" );

            // Split the array in two parts lower equal to pivot and greater than pivot
            const IndexType nbLowerElements = QsPartition(workingArray, 0, currentSize-1, globalPivot);
            const IndexType nbGreaterElements = currentSize - nbLowerElements;

            ////FLOG( FLog::Controller << currentComm.processId() << "] After Partition: lower = " << nbLowerElements << " greater = " << nbGreaterElements << "\n"; )

            const int currentRank = currentComm.processId();
            const int currentNbProcs = currentComm.processCount();

            // We need to know what each process is holding
            Partition currentElementsBalance = { nbLowerElements, nbGreaterElements};
            std::vector<Partition> globalElementBalance(currentNbProcs);

            // Every one in the group need to know
            FMpi::Assert( MPI_Allgather( &currentElementsBalance, sizeof(Partition), MPI_BYTE, globalElementBalance.data(),
                                         sizeof(Partition), MPI_BYTE, currentComm.getComm()),  __LINE__ );

            // Find the number of elements lower or greater (and the exclusive prefix sums)
            IndexType globalNumberOfElementsGreater = 0;
            IndexType globalNumberOfElementsLower = 0;
            std::vector<Partition> globalElementBalanceSum(currentNbProcs + 1);
            globalElementBalanceSum[0].lowerPart = 0;
            globalElementBalanceSum[0].greaterPart = 0;
            for(int idxProc = 0 ; idxProc < currentNbProcs ; ++idxProc){
                globalElementBalanceSum[idxProc + 1].lowerPart = globalElementBalanceSum[idxProc].lowerPart + globalElementBalance[idxProc].lowerPart;
                globalElementBalanceSum[idxProc + 1].greaterPart = globalElementBalanceSum[idxProc].greaterPart + globalElementBalance[idxProc].greaterPart;
                globalNumberOfElementsGreater += globalElementBalance[idxProc].greaterPart;
                globalNumberOfElementsLower   += globalElementBalance[idxProc].lowerPart;
            }

            ////FLOG( FLog::Controller << currentComm.processId() << "] globalNumberOfElementsGreater = " << globalNumberOfElementsGreater << "\n"; )
            ////FLOG( FLog::Controller << currentComm.processId() << "] globalNumberOfElementsLower   = " << globalNumberOfElementsLower << "\n"; )

            // The last rank of the lower group, chosen proportionally to the share of
            // lower elements (ranks [0, procInTheMiddle] keep the lower elements)
            int procInTheMiddle;
            if(globalNumberOfElementsLower == 0)        procInTheMiddle = -1;
            else if(globalNumberOfElementsGreater == 0) procInTheMiddle = currentNbProcs-1;
            else procInTheMiddle = FMath::Min(IndexType(currentNbProcs-2), (currentNbProcs*globalNumberOfElementsLower)
                                              /(globalNumberOfElementsGreater + globalNumberOfElementsLower));

            ////FLOG( FLog::Controller << currentComm.processId() << "] procInTheMiddle = " << procInTheMiddle << "\n"; )

            // Send or receive depending on the state. The lower group sends first
            // while the greater group receives first, so the two phases match.
            if(currentRank <= procInTheMiddle){
                // I am in the group of the lower elements
                SendDistribution(workingArray + nbLowerElements, nbGreaterElements,
                                 globalElementBalance.data(), globalElementBalanceSum.data(), procInTheMiddle, currentComm, false);
                SortType* lowerPartRecv = nullptr;
                IndexType nbLowerElementsRecv = 0;
                RecvDistribution(&lowerPartRecv, &nbLowerElementsRecv,
                                 globalElementBalance.data(), globalElementBalanceSum.data(), procInTheMiddle, currentComm, true);
                // Merge previous part and just received elements
                const IndexType fullNbLowerElementsRecv = nbLowerElementsRecv + nbLowerElements;
                SortType* fullLowerPart = new SortType[fullNbLowerElementsRecv];
                memcpy(fullLowerPart, workingArray, sizeof(SortType)* nbLowerElements);
                memcpy(fullLowerPart + nbLowerElements, lowerPartRecv, sizeof(SortType)* nbLowerElementsRecv);
                delete[] workingArray;
                delete[] lowerPartRecv;
                workingArray = fullLowerPart;
                currentSize = fullNbLowerElementsRecv;
                // Reduce working group
                ////FLOG( FLog::Controller << currentComm.processId() << "] Reduce group to " << 0 << " / " << procInTheMiddle << "\n"; )
                currentComm.groupReduce( 0, procInTheMiddle);
                ////FLOG( FLog::Controller << currentComm.processId() << "] Done\n" );
            }
            else {
                // I am in the group of the greater elements
                SortType* greaterPartRecv = nullptr;
                IndexType nbGreaterElementsRecv = 0;
                RecvDistribution(&greaterPartRecv, &nbGreaterElementsRecv,
                                 globalElementBalance.data(), globalElementBalanceSum.data(), procInTheMiddle, currentComm, false);
                SendDistribution(workingArray, nbLowerElements,
                                 globalElementBalance.data(), globalElementBalanceSum.data(), procInTheMiddle, currentComm, true);
                // Merge previous part and just received elements
                const IndexType fullNbGreaterElementsRecv = nbGreaterElementsRecv + nbGreaterElements;
                SortType* fullGreaterPart = new SortType[fullNbGreaterElementsRecv];
                memcpy(fullGreaterPart, workingArray + nbLowerElements, sizeof(SortType)* nbGreaterElements);
                memcpy(fullGreaterPart + nbGreaterElements, greaterPartRecv, sizeof(SortType)* nbGreaterElementsRecv);
                delete[] workingArray;
                delete[] greaterPartRecv;
                workingArray = fullGreaterPart;
                currentSize = fullNbGreaterElementsRecv;
                // Reduce working group
                ////FLOG( FLog::Controller << currentComm.processId() << "] Reduce group to " << procInTheMiddle + 1 << " / " << currentNbProcs - 1 << "\n"; )
                currentComm.groupReduce( procInTheMiddle + 1, currentNbProcs - 1);
                ////FLOG( FLog::Controller << currentComm.processId() << "] Done\n"; )
            }
        }

        ////FLOG( FLog::Controller << currentComm.processId() << "] Sequential sort\n"; )

        // Finish by a local sort
        FQuickSort< SortType, CompareType, IndexType>::QsOmp(workingArray, currentSize);
        (*outputSize)  = currentSize;
        (*outputArray) = workingArray;
    }
};

#endif // FQUICKSORTMPI_HPP