// FSVDBlock.hpp
// ===================================================================================
// Copyright ScalFmm 2011 INRIA, Olivier Coulaud, Berenger Bramas, Matthias Messner
// olivier.coulaud@inria.fr, berenger.bramas@inria.fr
// This software is a computer program whose purpose is to compute the FMM.
//
// This software is governed by the CeCILL-C and LGPL licenses and
// abiding by the rules of distribution of free software.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public and CeCILL-C Licenses for more details.
// "http://www.cecill.info".
// "http://www.gnu.org/licenses".
// ===================================================================================
// 
// @SCALFMM_PRIVATE
// 
#ifndef FSVDBLOCK_HPP
#define FSVDBLOCK_HPP

#include "Utils/FBlas.hpp"

/*
 * Compute SVD $A=USV'$ and return $V'$, $S$ and $U$.
 * \param A contains input M x N matrix to be decomposed
 * \param S contains singular values $S$
 * \param U contains $U$
 * \param VT contains $V'$
 */
template<class FReal>
static void computeSVD(const FSize nbRows, const FSize nbCols, const FReal* A, FReal* S, FReal* U, FReal* VT){
    // verbose
    const bool verbose = false;
    // copy A
    //is_int(size*size);
    FBlas::copy(int(nbRows*nbCols),A,U);
    // init SVD
    const FSize minMN = std::min(nbRows,nbCols);
    const FSize maxMN = std::max(nbRows,nbCols);
    //const FSize LWORK = 2*4*minMN; // for square matrices
    const FSize LWORK = 2*std::max(3*minMN+maxMN, 5*minMN);
    FReal *const WORK = new FReal [LWORK];
    // singular value decomposition
    if(verbose) std::cout << "\nPerform SVD...";
    // SO means that first min(m,n) lines of U overwritten on VT and V' on U (A=VTSU)
    // AA means all lines
    // nothing means OS (the opposite of SO, A=USVT)
    //is_int(size); is_int(LWORK);
    const unsigned int INFOSVD
51
    = FBlas::gesvd(int(nbRows), int(nbCols), U, S, VT, int(minMN)/*ldVT*/,
52 53 54 55 56 57 58 59 60
                   int(LWORK), WORK);
    if(verbose) {
        if(INFOSVD!=0) {std::cout << " failed!" << std::endl;}
        else {std::cout << " succeed!" << std::endl;}
    }
    // free memory
    delete[] WORK;
}

//

/*
 * Determine the numerical rank of a matrix from its singular values.
 *
 * With n the value of rank on entry, the result is the smallest r in [1,n] such that
 *   sqrt(S[0]^2 + ... + S[r-1]^2) >= (1-epsilon) * sqrt(S[0]^2 + ... + S[n-1]^2),
 * or n if no smaller r satisfies the criterion.
 * If n <= 0 there are no singular values to inspect and rank is set to 0.
 *
 * \param rank    in: number of singular values available in S; out: numerical rank
 * \param S       singular values, the first rank (entry value) entries are read
 * \param epsilon relative accuracy used in the criterion above
 */
template<class FReal>
static void computeNumericalRank(int &rank, const FReal* S, const FReal epsilon){
    const int maxRank = rank;

    // Nothing to look at: do not touch S[0] and report an empty rank
    if(maxRank <= 0){
        rank = 0;
        return;
    }

    // Norm of the full set of singular values
    FReal sumSigma2 = FReal(0.0);
    for(int idxSigma = 0 ; idxSigma < maxRank ; ++idxSigma)
        sumSigma2 += S[idxSigma]*S[idxSigma];

    // The right-hand side of the criterion does not change while the rank grows
    const FReal threshold = (FReal(1.)-epsilon)*std::sqrt(sumSigma2);

    // Start from rank 1 and add singular values until the criterion is met
    rank = 1;
    FReal sumSigma2r = S[0]*S[0];
    while(std::sqrt(sumSigma2r) < threshold && rank < maxRank){
        sumSigma2r += S[rank]*S[rank];
        rank++;
    }
}


//
template <class FReal, int ORDER = 14>
96 97 98 99 100 101 102 103 104 105 106
class FSVDBlock{
protected:
    // members
    FReal* block;
    FReal* U;
    FReal* S;
    FReal* VT;
    int nbRows;
    int nbCols;
    int level;
    int rank;
107
    FReal accuracy;
108 109 110 111 112
    FSVDBlock(const FSVDBlock&) = delete;
    FSVDBlock& operator=(const FSVDBlock&) = delete;

public:
    FSVDBlock()
113
        : block(nullptr), U(nullptr), S(nullptr), VT(nullptr), nbRows(0), nbCols(0),  level(0), rank(0), accuracy(FMath::pow(FReal(10.0),static_cast<FReal>(-ORDER))) {
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
    }

    // ctor
    template <class ViewerClass>
    void fill(const ViewerClass& viewer, const int inLevel){
        clear();
        // Allocate memory
        level  = inLevel;
        nbRows = viewer.getNbRows();
        nbCols = viewer.getNbCols();
        block  = new FReal[nbRows*nbCols];

        for(int idxRow = 0 ; idxRow < nbRows ; ++idxRow){
            for(int idxCol = 0 ; idxCol < nbCols ; ++idxCol){
                block[idxCol*nbRows+idxRow] = viewer.getValue(idxRow,idxCol);
            }
        }

        // SVD specific (col major)
        rank = std::min(nbRows,nbCols);
        S  = new FReal[rank];
135
        FReal* _U  = new FReal[nbRows*nbCols]; // Call to computeSVD() copies block MxN into _U and stores first min(M,N) cols of U into _U
136
        FReal* _VT = new FReal[rank*nbCols];
137
        FBlas::setzero(int(rank), S);        
138
        FBlas::setzero(int(rank*nbCols),_VT);
139
        // Perform decomposition of rectangular block (jobu=O, jobvt=S => only first min(M,N) cols/rows of U/VT are stored)
140
        computeSVD(nbRows, nbCols, block, S, _U ,_VT);
141

142 143 144
        // Determine numerical rank using prescribed accuracy
        computeNumericalRank(rank, S, accuracy);

145 146 147 148 149 150
        //// Display singular values
        //std::cout << "S = [";
        //for(int idxRow = 0 ; idxRow < rank ; ++idxRow)
        //    std::cout << S[idxRow] << " " ;
        //std::cout << "]" << std::endl;

151
        //// display rank
152
        std::cout << "rank SVD =" << rank << " (" << nbRows << "," << nbCols << ")" << std::endl;
153

154 155 156 157 158 159 160 161 162 163 164 165 166
        // Resize U and VT
        U  = new FReal[nbRows*rank]; // Call to computeSVD() copies block into U
        VT = new FReal[rank*nbCols];

        for(int idxRow = 0 ; idxRow < nbRows ; ++idxRow)
            for(int idxCol = 0 ; idxCol < rank ; ++idxCol)
                U[idxCol*nbRows+idxRow] = _U[idxCol*nbRows+idxRow];
        for(int idxRow = 0 ; idxRow < rank ; ++idxRow)
            for(int idxCol = 0 ; idxCol < nbCols ; ++idxCol)
                VT[idxCol*rank+idxRow] = _VT[idxCol*std::min(nbRows,nbCols)+idxRow]; 
        delete [] _U;           
        delete [] _VT;

167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
    };

    // dtor
    ~FSVDBlock(){
        // Free memory
        clear();
    };

    void clear(){
        nbRows = 0;
        nbCols = 0;
        level = 0;
        rank = 0;
        delete[] block;
        block = 0;
        delete[] U;
        U = 0;
        delete[] S;
        S = 0;
        delete[] VT;
        VT = 0;
    }

    void gemv(FReal res[], const FReal vec[], const FReal scale = FReal(1.)) const {
191

192
        //// Apply (dense) block
193 194 195
        //FReal* res_dense = new FReal[nbRows];
        //FBlas::copy(nbRows,res,res_dense);
        //FBlas::gemva(nbRows, nbCols, scale, const_cast<FReal*>(block), const_cast<FReal*>(vec), res_dense);
196 197
        
        // Apply low-rank block
198 199 200
        FReal* VTvec = new FReal[rank];
        FBlas::setzero(rank,VTvec);

201 202
        // Apply VT
        FBlas::gemv(rank, nbCols, scale, const_cast<FReal*>(VT), const_cast<FReal*>(vec), VTvec);
203

204 205
        // Apply S
        for(int idxS = 0 ; idxS < rank ; ++idxS)
206 207
            VTvec[idxS]*=S[idxS];    
        
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
        // Apply U
        FBlas::gemva(nbRows, rank, scale, const_cast<FReal*>(U), const_cast<FReal*>(VTvec), res);

    }

    void gemm(FReal res[], const FReal mat[], const int nbRhs, const FReal scale = FReal(1.)) const {
        
        //// Apply (dense) block
        //FBlas::gemma(nbRows, nbCols, nbRhs, scale, const_cast<FReal*>(block), nbRows, const_cast<FReal*>(mat), nbCols, res, nbRows);

        // Apply low-rank block
        FReal* VTmat = new FReal[nbCols*nbRhs];
        // Apply VT
        FBlas::gemm(rank, nbCols, nbRhs, scale, const_cast<FReal*>(VT), rank, const_cast<FReal*>(mat), nbCols, VTmat, rank);
        // Apply S
        for(int idxRow = 0 ; idxRow < rank ; ++idxRow)
            for(int idxRhs = 0 ; idxRhs < nbRhs ; ++idxRhs)
                VTmat[idxRhs*rank+idxRow]*=S[idxRow];
        // Apply U
        FBlas::gemma(nbRows, rank, nbRhs, scale, const_cast<FReal*>(U), nbRows, const_cast<FReal*>(VTmat), rank, res, nbRows);

    }

231 232 233
    int getRank() const{
        return rank;
    }
234 235 236 237 238
};

#endif // FSVDBLOCK_HPP

// [--END--]