Commit e98870ea authored by BRAMAS Berenger's avatar BRAMAS Berenger

update hmat

parent 575e00ef
...@@ -55,6 +55,17 @@ public: ...@@ -55,6 +55,17 @@ public:
} }
}; };
/**
 * Resize the block to inNbRow x inNbCol values and set every value to zero.
 * The storage is released through clear() and reallocated only when the
 * requested dimensions differ from the current ones; otherwise the existing
 * buffer is reused. In both cases the content is zero-filled with memset.
 * @param inNbRow requested number of rows (expected to be >= 0)
 * @param inNbCol requested number of columns (expected to be >= 0)
 */
void resize(const int inNbRow, const int inNbCol){
    // Count the values in a 64 bits integer: the int product inNbRow*inNbCol
    // overflows for large blocks, which would make the allocation smaller
    // than the area zeroed below.
    const long long nbValues = static_cast<long long>(inNbRow) * static_cast<long long>(inNbCol);
    if(inNbRow != nbRows ||
            inNbCol != nbCols){
        clear();
        block = new FReal[nbValues];
        // Commit the new dimensions only once the allocation has succeeded
        nbRows = inNbRow;
        nbCols = inNbCol;
    }
    // Nothing to zero for an empty block, and this avoids handing memset
    // a pointer that may be null in that case
    if(nbValues > 0){
        memset(block, 0, sizeof(FReal)*nbValues);
    }
}
// dtor // dtor
~FDenseBlock(){ ~FDenseBlock(){
// Free memory // Free memory
...@@ -69,6 +80,26 @@ public: ...@@ -69,6 +80,26 @@ public:
block = 0; block = 0;
} }
int getNbRows() const{
return nbRows;
}
int getNbCols() const{
return nbCols;
}
/**
 * Read the value at position (idxRow, idxCol).
 * Values are addressed as idxCol*nbRows + idxRow inside the buffer.
 * No bounds checking is performed.
 * @param idxRow row index of the value
 * @param idxCol column index of the value
 * @return a copy of the stored value
 */
FReal getValue(const int idxRow, const int idxCol) const{
    // Offset computed in 64 bits: the int expression overflows once the
    // block holds more than 2^31 values
    const long long idxValue = static_cast<long long>(idxCol)*nbRows + idxRow;
    return block[idxValue];
}
/**
 * Access the value at position (idxRow, idxCol) for modification.
 * Values are addressed as idxCol*nbRows + idxRow inside the buffer.
 * No bounds checking is performed.
 * @param idxRow row index of the value
 * @param idxCol column index of the value
 * @return a reference to the stored value
 */
FReal& getValue(const int idxRow, const int idxCol) {
    // Offset computed in 64 bits: the int expression overflows once the
    // block holds more than 2^31 values
    const long long idxValue = static_cast<long long>(idxCol)*nbRows + idxRow;
    return block[idxValue];
}
/**
 * Store val at position (idxRow, idxCol).
 * Values are addressed as idxCol*nbRows + idxRow inside the buffer.
 * No bounds checking is performed.
 * @param idxRow row index of the value
 * @param idxCol column index of the value
 * @param val value to copy into the block
 */
void setValue(const int idxRow, const int idxCol, const FReal& val) {
    // Offset computed in 64 bits: the int expression overflows once the
    // block holds more than 2^31 values
    const long long idxValue = static_cast<long long>(idxCol)*nbRows + idxRow;
    block[idxValue] = val;
}
void gemv(FReal res[], const FReal vec[], const FReal scale = FReal(1.)) const { void gemv(FReal res[], const FReal vec[], const FReal scale = FReal(1.)) const {
FBlas::gemva(nbRows, nbCols, scale, const_cast<FReal*>(block), const_cast<FReal*>(vec), res); FBlas::gemva(nbRows, nbCols, scale, const_cast<FReal*>(block), const_cast<FReal*>(vec), res);
} }
......
...@@ -29,31 +29,31 @@ ...@@ -29,31 +29,31 @@
#include <memory> #include <memory>
template <class FReal, class RowBlockClass, class ColBlockClass, class CoreCellClass > template <class FReal, class CellClass >
class FBlockPMapping { class FBlockPMapping {
protected: protected:
struct CellNode { struct CellCNode {
FBlockDescriptor infos; FBlockDescriptor infos;
CoreCellClass cell; CellClass cell;
}; };
struct RowNode { struct RowUNode {
FBlockDescriptor infos; FBlockDescriptor infos;
RowBlockClass cell; CellClass cell;
}; };
struct ColNode { struct ColVNode {
FBlockDescriptor infos; FBlockDescriptor infos;
ColBlockClass cell; CellClass cell;
}; };
const int dim; const int dim;
const int nbPartitions; const int nbPartitions;
const int nbCells; const int nbCells;
CellNode* cells; CellCNode* cBlocks;
RowNode* rowBlocks; RowUNode* uRowBlocks;
ColNode* colBlocks; ColVNode* vColBlocks;
FBlockPMapping(const FBlockPMapping&) = delete; FBlockPMapping(const FBlockPMapping&) = delete;
FBlockPMapping& operator=(const FBlockPMapping&) = delete; FBlockPMapping& operator=(const FBlockPMapping&) = delete;
...@@ -63,7 +63,7 @@ public: ...@@ -63,7 +63,7 @@ public:
: dim(inDim), : dim(inDim),
nbPartitions(inNbPartitions), nbPartitions(inNbPartitions),
nbCells(inNbPartitions*inNbPartitions), nbCells(inNbPartitions*inNbPartitions),
cells(nullptr){ cBlocks(nullptr){
FAssertLF(nbPartitions <= inDim); FAssertLF(nbPartitions <= inDim);
FAssertLF(1 <= nbPartitions); FAssertLF(1 <= nbPartitions);
...@@ -73,41 +73,41 @@ public: ...@@ -73,41 +73,41 @@ public:
partitionsOffset[idxPart] = partitionsOffset[idxPart-1] + partitions[idxPart-1]; partitionsOffset[idxPart] = partitionsOffset[idxPart-1] + partitions[idxPart-1];
} }
cells = new CellNode[nbCells]; cBlocks = new CellCNode[nbCells];
for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){ for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){
for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){ for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){
cells[idxPartCol*nbPartitions + idxPartRow].infos.row = partitionsOffset[idxPartRow]; cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.row = partitionsOffset[idxPartRow];
cells[idxPartCol*nbPartitions + idxPartRow].infos.col = partitionsOffset[idxPartCol]; cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.col = partitionsOffset[idxPartCol];
cells[idxPartCol*nbPartitions + idxPartRow].infos.nbRows = partitions[idxPartRow]; cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.nbRows = partitions[idxPartRow];
cells[idxPartCol*nbPartitions + idxPartRow].infos.nbCols = partitions[idxPartCol]; cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.nbCols = partitions[idxPartCol];
cells[idxPartCol*nbPartitions + idxPartRow].infos.level = 0; cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.level = 0;
} }
} }
rowBlocks = new RowNode[nbPartitions]; uRowBlocks = new RowUNode[nbPartitions];
for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){ for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){
rowBlocks[idxPartRow].infos.row = partitionsOffset[idxPartRow]; uRowBlocks[idxPartRow].infos.row = partitionsOffset[idxPartRow];
rowBlocks[idxPartRow].infos.col = 0; uRowBlocks[idxPartRow].infos.col = 0;
rowBlocks[idxPartRow].infos.nbRows = partitions[idxPartRow]; uRowBlocks[idxPartRow].infos.nbRows = partitions[idxPartRow];
rowBlocks[idxPartRow].infos.nbCols = dim; uRowBlocks[idxPartRow].infos.nbCols = dim;
rowBlocks[idxPartRow].infos.level = 0; uRowBlocks[idxPartRow].infos.level = 0;
} }
colBlocks = new ColNode[nbPartitions]; vColBlocks = new ColVNode[nbPartitions];
for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){ for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){
colBlocks[idxPartCol].infos.row = 0; vColBlocks[idxPartCol].infos.row = 0;
colBlocks[idxPartCol].infos.col = partitionsOffset[idxPartCol]; vColBlocks[idxPartCol].infos.col = partitionsOffset[idxPartCol];
colBlocks[idxPartCol].infos.nbRows = dim; vColBlocks[idxPartCol].infos.nbRows = dim;
colBlocks[idxPartCol].infos.nbCols = partitions[idxPartCol]; vColBlocks[idxPartCol].infos.nbCols = partitions[idxPartCol];
colBlocks[idxPartCol].infos.level = 0; vColBlocks[idxPartCol].infos.level = 0;
} }
} }
~FBlockPMapping(){ ~FBlockPMapping(){
delete[] cells; delete[] cBlocks;
delete[] rowBlocks; delete[] uRowBlocks;
delete[] colBlocks; delete[] vColBlocks;
} }
int getNbBlocks() const { int getNbBlocks() const {
...@@ -116,74 +116,74 @@ public: ...@@ -116,74 +116,74 @@ public:
// Iterate blocks // Iterate blocks
CoreCellClass& getCell(const int idxRowPart, const int idxColPart){ CellClass& getCBlock(const int idxRowPart, const int idxColPart){
return cells[idxColPart*nbPartitions + idxRowPart].cell; return cBlocks[idxColPart*nbPartitions + idxRowPart].cell;
} }
const CoreCellClass& getCell(const int idxRowPart, const int idxColPart) const { const CellClass& getCBlock(const int idxRowPart, const int idxColPart) const {
return cells[idxColPart*nbPartitions + idxRowPart].cell; return cBlocks[idxColPart*nbPartitions + idxRowPart].cell;
} }
const FBlockDescriptor& getCellInfo(const int idxRowPart, const int idxColPart) const { const FBlockDescriptor& getCBlockInfo(const int idxRowPart, const int idxColPart) const {
return cells[idxColPart*nbPartitions + idxRowPart].infos; return cBlocks[idxColPart*nbPartitions + idxRowPart].infos;
} }
void forAllBlocksDescriptor(std::function<void(const FBlockDescriptor&)> callback){ void forAllCBlocksDescriptor(std::function<void(const FBlockDescriptor&)> callback){
for(int idxCell = 0 ; idxCell < nbCells ; ++idxCell){ for(int idxCell = 0 ; idxCell < nbCells ; ++idxCell){
callback(cells[idxCell].infos); callback(cBlocks[idxCell].infos);
} }
} }
void forAllCellBlocks(std::function<void(const FBlockDescriptor&, void forAllBlocks(std::function<void(const FBlockDescriptor&,
RowBlockClass&, CoreCellClass&, ColBlockClass&)> callback){ CellClass&, CellClass&, CellClass&)> callback){
for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){ for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){
for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){ for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){
callback(cells[idxPartCol*nbPartitions + idxPartRow].infos, callback(cBlocks[idxPartCol*nbPartitions + idxPartRow].infos,
cells[idxPartCol*nbPartitions + idxPartRow].cell, cBlocks[idxPartCol*nbPartitions + idxPartRow].cell,
rowBlocks[idxPartRow].cell, uRowBlocks[idxPartRow].cell,
colBlocks[idxPartCol].cell); vColBlocks[idxPartCol].cell);
} }
} }
} }
// Iterate row blocks // Iterate row blocks
RowBlockClass& getRowCell(const int idxRowPart){ CellClass& getUBlock(const int idxRowPart){
return rowBlocks[idxRowPart].cell; return uRowBlocks[idxRowPart].cell;
} }
const RowBlockClass& getRowCell(const int idxRowPart) const { const CellClass& getUBlock(const int idxRowPart) const {
return rowBlocks[idxRowPart].cell; return uRowBlocks[idxRowPart].cell;
} }
const FBlockDescriptor& getRowCellInfo(const int idxRowPart) const { const FBlockDescriptor& getUBlockInfo(const int idxRowPart) const {
return rowBlocks[idxRowPart].infos; return uRowBlocks[idxRowPart].infos;
} }
// Iterate col blocks // Iterate col blocks
ColBlockClass& getColCell(const int idxColPart){ CellClass& getVBlock(const int idxColPart){
return colBlocks[idxColPart].cell; return vColBlocks[idxColPart].cell;
} }
const ColBlockClass& getColCell(const int idxColPart) const { const CellClass& getVBlock(const int idxColPart) const {
return colBlocks[idxColPart].cell; return vColBlocks[idxColPart].cell;
} }
const FBlockDescriptor& getColCellInfo(const int idxColPart) const { const FBlockDescriptor& getVBlockInfo(const int idxColPart) const {
return colBlocks[idxColPart].infos; return vColBlocks[idxColPart].infos;
} }
// Operations // Operations
void gemv(FReal res[], const FReal vec[]) const { void gemv(FReal res[], const FReal vec[]) const {
for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){ for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){
for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){ for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){
// &res[cells[idxPartCol*nbPartitions + idxPartRow].infos.row], // &res[cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.row],
// &vec[cells[idxPartCol*nbPartitions + idxPartRow].infos.col]) // &vec[cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.col])
// cells[idxPartCol*nbPartitions + idxPartRow].cell, // cBlocks[idxPartCol*nbPartitions + idxPartRow].cell,
// rowBlocks[idxPartRow].cell, // uRowBlocks[idxPartRow].cell,
// colBlocks[idxPartCol].cell; // vColBlocks[idxPartCol].cell;
} }
} }
} }
...@@ -191,11 +191,11 @@ public: ...@@ -191,11 +191,11 @@ public:
void gemm(FReal res[], const FReal mat[], const int nbRhs) const { void gemm(FReal res[], const FReal mat[], const int nbRhs) const {
for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){ for(int idxPartCol = 0 ; idxPartCol < nbPartitions ; ++idxPartCol){
for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){ for(int idxPartRow = 0 ; idxPartRow < nbPartitions ; ++idxPartRow){
// &res[cells[idxPartCol*nbPartitions + idxPartRow].infos.row], // &res[cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.row],
// &vec[cells[idxPartCol*nbPartitions + idxPartRow].infos.col]) // &vec[cBlocks[idxPartCol*nbPartitions + idxPartRow].infos.col])
// cells[idxPartCol*nbPartitions + idxPartRow].cell, // cBlocks[idxPartCol*nbPartitions + idxPartRow].cell,
// rowBlocks[idxPartRow].cell, // uRowBlocks[idxPartRow].cell,
// colBlocks[idxPartCol].cell; // vColBlocks[idxPartCol].cell;
// nbRhs, dim // nbRhs, dim
} }
} }
......
...@@ -81,7 +81,7 @@ int main(int argc, char** argv){ ...@@ -81,7 +81,7 @@ int main(int argc, char** argv){
{ {
typedef FDenseBlock<FReal> CellClass; typedef FDenseBlock<FReal> CellClass;
typedef FBlockPMapping<FReal, CellClass, CellClass, CellClass> GridClass; typedef FBlockPMapping<FReal, CellClass> GridClass;
std::unique_ptr<int[]> partitions(new int[nbPartitions]); std::unique_ptr<int[]> partitions(new int[nbPartitions]);
{ {
...@@ -96,20 +96,64 @@ int main(int argc, char** argv){ ...@@ -96,20 +96,64 @@ int main(int argc, char** argv){
GridClass grid(dim, partitions.get(), nbPartitions); GridClass grid(dim, partitions.get(), nbPartitions);
// We iterate on the blocks // We iterate on the blocks
// V blocks cover all the rows, but only some columns (based on the clustering)
for(int idxColBlock = 0 ; idxColBlock < nbPartitions ; ++idxColBlock){ for(int idxColBlock = 0 ; idxColBlock < nbPartitions ; ++idxColBlock){
const MatrixClass::BlockDescriptor colBlock = matrix.getBlock(grid.getColCellInfo(idxColBlock)); const MatrixClass::BlockDescriptor colBlock = matrix.getBlock(grid.getVBlockInfo(idxColBlock));
int rj = -1;
// Store the result in grid.getColCell(idxColBlock) /// TODO HERE
/// Compute rj, and the resulting Vj blocks,
/// use the colBlock (some or all of its values)
/// TODO END
// Store the result in grid.getVBlock(idxColBlock)
CellClass& Vj = grid.getVBlock(idxColBlock);
Vj.resize(rj, colBlock.getNbCols());
for(int idxRow = 0 ; idxRow < Vj.getNbRows() ; ++idxRow){
for(int idxCol = 0 ; idxCol < Vj.getNbCols() ; ++idxCol){
/// TODO HERE
/// Fill Vj with the result
Vj.setValue(idxRow, idxCol, -1);
/// TODO END
}
} }
}
// U blocks cover all the columns, but only some rows (based on the clustering)
for(int idxRowBlock = 0 ; idxRowBlock < nbPartitions ; ++idxRowBlock){ for(int idxRowBlock = 0 ; idxRowBlock < nbPartitions ; ++idxRowBlock){
const MatrixClass::BlockDescriptor rowBlock = matrix.getBlock(grid.getRowCellInfo(idxRowBlock)); const MatrixClass::BlockDescriptor rowBlock = matrix.getBlock(grid.getUBlockInfo(idxRowBlock));
int ri = -1;
// Store the result in grid.getRowCell(idxRowBlock) /// TODO HERE
/// Compute ri, and the resulting Ui blocks
/// use the rowBlock (some or all of its values)
/// TODO END
// Store the result in grid.getUBlock(idxRowBlock)
CellClass& Ui = grid.getUBlock(idxRowBlock);
Ui.resize(rowBlock.getNbRows(), ri);
for(int idxRow = 0 ; idxRow < Ui.getNbRows() ; ++idxRow){
for(int idxCol = 0 ; idxCol < Ui.getNbCols() ; ++idxCol){
/// TODO HERE
/// Fill Ui with the result
Ui.setValue(idxRow, idxCol, -1);
/// TODO END
}
}
} }
// Build the core part // Build the core part
for(int idxColBlock = 0 ; idxColBlock < nbPartitions ; ++idxColBlock){ for(int idxColBlock = 0 ; idxColBlock < nbPartitions ; ++idxColBlock){
for(int idxRowBlock = 0 ; idxRowBlock < nbPartitions ; ++idxRowBlock){ for(int idxRowBlock = 0 ; idxRowBlock < nbPartitions ; ++idxRowBlock){
// Store the result in grid.getCell(idxRowBlock, idxColBlock) const CellClass& Ui = grid.getUBlock(idxRowBlock);
const CellClass& Vj = grid.getVBlock(idxColBlock);
// Store the result in grid.getCBlock(idxRowBlock, idxColBlock)
CellClass& Cij = grid.getCBlock(idxRowBlock, idxColBlock);
Cij.resize(Vj.getNbRows(), Ui.getNbCols());
for(int idxRow = 0 ; idxRow < Cij.getNbRows() ; ++idxRow){
for(int idxCol = 0 ; idxCol < Cij.getNbCols() ; ++idxCol){
/// TODO HERE
/// Fill Cij with the result
Cij.setValue(idxRow, idxCol, -1);
/// TODO END
}
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment