diff --git a/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.cpp b/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.cpp index 9dc039743902f79fc2d1e5efdc9cb9a3b69a5262..5be62b201c1377affaea31c032bfe651eea10d96 100644 --- a/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.cpp +++ b/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.cpp @@ -29,10 +29,10 @@ OpenViBE::int32 OpenViBEPlugins::Classification::LDAClassificationCompare(OpenVi { //We first need to find the best classification of each. OpenViBE::float64* l_pClassificationValueBuffer = rFirstClassificationValue.getBuffer(); - OpenViBE::float64 l_f64MaxFirst = *(std::max_element(l_pClassificationValueBuffer, l_pClassificationValueBuffer+rFirstClassificationValue.getBufferElementCount())); + const OpenViBE::float64 l_f64MaxFirst = *(std::max_element(l_pClassificationValueBuffer, l_pClassificationValueBuffer+rFirstClassificationValue.getBufferElementCount())); l_pClassificationValueBuffer = rSecondClassificationValue.getBuffer(); - OpenViBE::float64 l_f64MaxSecond = *(std::max_element(l_pClassificationValueBuffer, l_pClassificationValueBuffer+rSecondClassificationValue.getBufferElementCount())); + const OpenViBE::float64 l_f64MaxSecond = *(std::max_element(l_pClassificationValueBuffer, l_pClassificationValueBuffer+rSecondClassificationValue.getBufferElementCount())); //Then we just compared them if(!ov_float_equal(l_f64MaxFirst, l_f64MaxSecond)) @@ -106,10 +106,10 @@ boolean CAlgorithmClassifierLDA::train(const IFeatureVectorSet& rFeatureVectorSe this->initializeExtraParameterMechanism(); //We need to clear list because a instance of this class should support more that one training. - m_oLabelList.clear(); - m_oComputationHelperList.clear(); + m_vLabelList.clear(); + m_vDiscriminantFunctions.clear(); - boolean l_bUseShrinkage = this->getBooleanParameter(OVP_Algorithm_ClassifierLDA_InputParameterId_UseShrinkage); + const boolean l_bUseShrinkage = this->getBooleanParameter(OVP_Algorithm_ClassifierLDA_InputParameterId_UseShrinkage); boolean l_pDiagonalCov; if(l_bUseShrinkage) @@ -147,37 +147,37 @@ boolean CAlgorithmClassifierLDA::train(const IFeatureVectorSet& rFeatureVectorSe } // Count the classes - std::map < float64, uint32 > l_vClassLabels; + std::map < float64, uint32 > l_vClassCounts; for(uint32 i=0; i<rFeatureVectorSet.getFeatureVectorCount(); i++) { - l_vClassLabels[rFeatureVectorSet[i].getLabel()]++; + l_vClassCounts[rFeatureVectorSet[i].getLabel()]++; } - const uint32 l_ui32nClasses = l_vClassLabels.size(); + const uint32 l_ui32nClasses = l_vClassCounts.size(); // Get class labels - for(std::map < float64, uint32 >::iterator iter = l_vClassLabels.begin() ; iter != l_vClassLabels.end() ; ++iter) + for(std::map < float64, uint32 >::iterator iter = l_vClassCounts.begin() ; iter != l_vClassCounts.end() ; ++iter) { - m_oLabelList.push_back(iter->first); - m_oComputationHelperList.push_back(CAlgorithmLDADiscriminantFunction()); + m_vLabelList.push_back(iter->first); + m_vDiscriminantFunctions.push_back(CAlgorithmLDADiscriminantFunction()); } - // Get regularized covariances of all the classes - VectorXd* l_aMean = new VectorXd[l_ui32nClasses]; + // Per-class means and a global covariance are used to form the LDA model + MatrixXd* l_oPerClassMeans = new MatrixXd[l_ui32nClasses]; MatrixXd l_oGlobalCov = MatrixXd::Zero(l_ui32nCols,l_ui32nCols); + // We need the means per class for(uint32 l_ui32classIdx=0;l_ui32classIdx<l_ui32nClasses;l_ui32classIdx++) { - MatrixXd l_aCov; - - const float64 l_f64Label = m_oLabelList[l_ui32classIdx]; - const uint32 l_ui32nExamplesInClass = l_vClassLabels[l_f64Label]; - - // Copy all the data of the class to a feature matrix - ip_pFeatureVectorSet->setDimensionCount(2); - ip_pFeatureVectorSet->setDimensionSize(0, l_ui32nExamplesInClass); - ip_pFeatureVectorSet->setDimensionSize(1, l_ui32nCols); - float64 *l_pBuffer = ip_pFeatureVectorSet->getBuffer(); + const float64 l_f64Label = m_vLabelList[l_ui32classIdx]; + const uint32 l_ui32nExamplesInClass = l_vClassCounts[l_f64Label]; + + // Copy all the data of the class to a matrix + CMatrix l_oClassData; + l_oClassData.setDimensionCount(2); + l_oClassData.setDimensionSize(0, l_ui32nExamplesInClass); + l_oClassData.setDimensionSize(1, l_ui32nCols); + float64 *l_pBuffer = l_oClassData.getBuffer(); for(uint32 i=0;i<l_ui32nRows;i++) { if(rFeatureVectorSet[i].getLabel() == l_f64Label) @@ -187,24 +187,39 @@ boolean CAlgorithmClassifierLDA::train(const IFeatureVectorSet& rFeatureVectorSe } } - // Compute mean and cov + // Get the mean out of it + Map<MatrixXdRowMajor> l_oDataMapper(l_oClassData.getBuffer(), l_ui32nExamplesInClass, l_ui32nCols); + const MatrixXd l_oClassMean = l_oDataMapper.colwise().mean().transpose(); + l_oPerClassMeans[l_ui32classIdx] = l_oClassMean; + } + + // We need a global covariance, use the regularized cov algorithm + { + ip_pFeatureVectorSet->setDimensionCount(2); + ip_pFeatureVectorSet->setDimensionSize(0, l_ui32nRows); + ip_pFeatureVectorSet->setDimensionSize(1, l_ui32nCols); + float64 *l_pBuffer = ip_pFeatureVectorSet->getBuffer(); + + // Insert all data as the input of the cov algorithm + for(uint32 i=0;i<l_ui32nRows;i++) + { + System::Memory::copy(l_pBuffer, rFeatureVectorSet[i].getBuffer(), l_ui32nCols*sizeof(float64)); + l_pBuffer += l_ui32nCols; + } + + // Compute cov if(!m_pCovarianceAlgorithm->process()) { - this->getLogManager() << LogLevel_Error << "Covariance computation failed for class " << l_ui32classIdx << " ("<< l_f64Label << ")\n"; + this->getLogManager() << LogLevel_Error << "Global covariance computation failed\n"; return false; } + // Get the results from the cov algorithm - Map<VectorXd> l_oMeanMapper(op_pMean->getBuffer(), l_ui32nCols); - l_aMean[l_ui32classIdx] = l_oMeanMapper; Map<MatrixXdRowMajor> l_oCovMapper(op_pCovarianceMatrix->getBuffer(), l_ui32nCols, l_ui32nCols); - l_aCov = l_oCovMapper; - - l_oGlobalCov += l_aCov; - - //dumpMatrix(this->getLogManager(), l_aMean[l_ui32classIdx], "Mean"); - //dumpMatrix(this->getLogManager(), l_aCov[l_ui32classIdx], "Shrinked cov"); + l_oGlobalCov = l_oCovMapper; } - l_oGlobalCov /= (double)l_ui32nClasses; + //dumpMatrix(this->getLogManager(), l_aMean[l_ui32classIdx], "Mean"); + //dumpMatrix(this->getLogManager(), l_oGlobalCov, "Shrinked cov"); if(l_pDiagonalCov) { @@ -233,24 +248,34 @@ boolean CAlgorithmClassifierLDA::train(const IFeatureVectorSet& rFeatureVectorSe //We send the bias and the weight of each class to ComputationHelper for(size_t i = 0 ; i < getClassCount() ; ++i) { - VectorXd l_oWeight = (l_oGlobalCovInv * l_aMean[i]); - const MatrixXd l_oInter = -0.5 * l_aMean[i].transpose() * l_oGlobalCovInv * l_aMean[i]; - float64 l_f64Bias = l_oInter(0,0) + std::log(m_oLabelList[i]/rFeatureVectorSet.getFeatureVectorCount()); + const float64 l_f64ExamplesInClass = l_vClassCounts[m_vLabelList[i]]; + const uint32 l_ui32TotalExamples = rFeatureVectorSet.getFeatureVectorCount(); + + // This formula e.g. in Hastie, Tibshirani & Friedman: "Elements...", 2nd ed., p. 109 + const VectorXd l_oWeight = (l_oGlobalCovInv * l_oPerClassMeans[i]); + const MatrixXd l_oInter = -0.5 * l_oPerClassMeans[i].transpose() * l_oGlobalCovInv * l_oPerClassMeans[i]; + const float64 l_f64Bias = l_oInter(0,0) + std::log(l_f64ExamplesInClass/l_ui32TotalExamples); + + // this->getLogManager() << LogLevel_Info << "Bias for " << i << " is " << l_f64Bias << ", from " << l_f64ExamplesInClass / l_ui32TotalExamples + // << ", " << stuffInClass << "/" << featVectors + // << "\n"; + // dumpMatrix(this->getLogManager(), l_oPerClassMeans[i], "Means"); + + m_vDiscriminantFunctions[i].setWeight(l_oWeight); + m_vDiscriminantFunctions[i].setBias(l_f64Bias); - m_oComputationHelperList[i].setWeight(l_oWeight); - m_oComputationHelperList[i].setBias(l_f64Bias); } m_ui32NumCols = l_ui32nCols; // Debug output - /*dumpMatrix(this->getLogManager(), l_oGlobalCov, "Global cov"); - dumpMatrix(this->getLogManager(), l_oEigenValues, "Eigenvalues"); - dumpMatrix(this->getLogManager(), l_oEigenSolver.eigenvectors(), "Eigenvectors"); - dumpMatrix(this->getLogManager(), l_oGlobalCovInv, "Global cov inverse"); - dumpMatrix(this->getLogManager(), m_oCoefficients, "Hyperplane weights");*/ + //dumpMatrix(this->getLogManager(), l_oGlobalCov, "Global cov"); + //dumpMatrix(this->getLogManager(), l_oEigenValues, "Eigenvalues"); + //dumpMatrix(this->getLogManager(), l_oEigenSolver.eigenvectors(), "Eigenvectors"); + //dumpMatrix(this->getLogManager(), l_oGlobalCovInv, "Global cov inverse"); + //dumpMatrix(this->getLogManager(), m_oCoefficients, "Hyperplane weights"); - delete[] l_aMean; + delete[] l_oPerClassMeans; return true; } @@ -268,7 +293,7 @@ boolean CAlgorithmClassifierLDA::classify(const IFeatureVector& rFeatureVector, return false; } - // Catenate 1.0 to match the bias term + // Catenate 1.0 to match the bias term MatrixXd l_oWeights(1, l_ui32nColsWithBiasTerm); l_oWeights(0,0) = 1.0; l_oWeights.block(0,1,1,l_ui32nColsWithBiasTerm-1) = l_oFeatureVec; @@ -286,17 +311,29 @@ boolean CAlgorithmClassifierLDA::classify(const IFeatureVector& rFeatureVector, if(l_f64P1 >= 0.5) { - rf64Class=m_oLabelList[0]; + rf64Class=m_vLabelList[0]; } else { - rf64Class=m_oLabelList[1]; + rf64Class=m_vLabelList[1]; } } else { + if(m_vDiscriminantFunctions.size() == 0) + { + this->getLogManager() << LogLevel_Error << "LDA discriminant function list is empty\n"; + return false; + } + + if(rFeatureVector.getSize() != m_vDiscriminantFunctions[0].getWeightVectorSize()) + { + this->getLogManager() << LogLevel_Error << "Classifier expected " << m_vDiscriminantFunctions[0].getWeightVectorSize() << " features, got " << rFeatureVector.getSize() << "\n"; + return false; + } + const Map<VectorXd> l_oFeatureVec(const_cast<float64*>(rFeatureVector.getBuffer()), rFeatureVector.getSize()); - VectorXd l_oWeights = l_oFeatureVec; + const VectorXd l_oWeights = l_oFeatureVec; const uint32 l_ui32ClassCount = getClassCount(); float64 *l_pValueArray = new float64[l_ui32ClassCount]; @@ -304,7 +341,7 @@ boolean CAlgorithmClassifierLDA::classify(const IFeatureVector& rFeatureVector, //We ask for all computation helper to give the corresponding class value for(size_t i = 0; i < l_ui32ClassCount ; ++i) { - l_pValueArray[i] = m_oComputationHelperList[i].getValue(l_oWeights); + l_pValueArray[i] = m_vDiscriminantFunctions[i].getValue(l_oWeights); } //p(Ck | x) = exp(ak) / sum[j](exp (aj)) @@ -335,7 +372,7 @@ boolean CAlgorithmClassifierLDA::classify(const IFeatureVector& rFeatureVector, rClassificationValues[i] = l_pValueArray[i]; rProbabilityValue[i] = l_pProbabilityValue[i]; } - rf64Class = m_oLabelList[l_ui32ClassIndex]; + rf64Class = m_vLabelList[l_ui32ClassIndex]; delete l_pValueArray; delete l_pProbabilityValue; @@ -345,7 +382,7 @@ boolean CAlgorithmClassifierLDA::classify(const IFeatureVector& rFeatureVector, uint32 CAlgorithmClassifierLDA::getClassCount() { - return m_oLabelList.size(); + return m_vLabelList.size(); } XML::IXMLNode* CAlgorithmClassifierLDA::saveConfiguration(void) @@ -358,14 +395,14 @@ XML::IXMLNode* CAlgorithmClassifierLDA::saveConfiguration(void) for(size_t i = 0; i< getClassCount() ; ++i) { - l_sClasses << m_oLabelList[i] << " "; + l_sClasses << m_vLabelList[i] << " "; } //Only new version should be recorded so we don't need to test XML::IXMLNode *l_pHelpersConfiguration = XML::createNode(c_sComputationHelpersConfigurationNode); - for(size_t i = 0; i < m_oComputationHelperList.size() ; ++i) + for(size_t i = 0; i < m_vDiscriminantFunctions.size() ; ++i) { - l_pHelpersConfiguration->addChild(m_oComputationHelperList[i].getConfiguration()); + l_pHelpersConfiguration->addChild(m_vDiscriminantFunctions[i].getConfiguration()); } XML::IXMLNode *l_pTempNode = XML::createNode(c_sClassesNodeName); @@ -398,7 +435,7 @@ boolean CAlgorithmClassifierLDA::loadConfiguration(XML::IXMLNode *pConfiguration m_bv1Classification = true; } - m_oLabelList.clear(); + m_vLabelList.clear(); XML::IXMLNode* l_pTempNode; @@ -441,8 +478,8 @@ boolean CAlgorithmClassifierLDA::loadConfiguration(XML::IXMLNode *pConfiguration for(size_t i = 0 ; i < l_pConfigsNode->getChildCount() ; ++i) { - m_oComputationHelperList.push_back(CAlgorithmLDADiscriminantFunction()); - m_oComputationHelperList[i].loadConfiguration(l_pConfigsNode->getChild(i)); + m_vDiscriminantFunctions.push_back(CAlgorithmLDADiscriminantFunction()); + m_vDiscriminantFunctions[i].loadConfiguration(l_pConfigsNode->getChild(i)); } } return true; @@ -454,7 +491,7 @@ void CAlgorithmClassifierLDA::loadClassesFromNode(XML::IXMLNode *pNode) float64 l_f64Temp; while(l_sData >> l_f64Temp) { - m_oLabelList.push_back(l_f64Temp); + m_vLabelList.push_back(l_f64Temp); } } diff --git a/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.h b/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.h index cf623119688aaa3bc4a512c8d2139f478639e9f1..3aa2038ddf33483f759b5cfbfafbb7792b5bc00b 100644 --- a/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.h +++ b/plugins/processing/classification/src/algorithms/ovpCAlgorithmClassifierLDA.h @@ -55,8 +55,8 @@ namespace OpenViBEPlugins // Debug method. Prints the matrix to the logManager. May be disabled in implementation. void dumpMatrix(OpenViBE::Kernel::ILogManager& pMgr, const MatrixXdRowMajor& mat, const OpenViBE::CString& desc); - std::vector < OpenViBE::float64 > m_oLabelList; - std::vector < CAlgorithmLDADiscriminantFunction > m_oComputationHelperList; + std::vector < OpenViBE::float64 > m_vLabelList; + std::vector < CAlgorithmLDADiscriminantFunction > m_vDiscriminantFunctions; Eigen::MatrixXd m_oCoefficients; Eigen::MatrixXd m_oWeights; diff --git a/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.cpp b/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.cpp index 8ee295b2fe6e8c1424e12f75a9c7a4603bf9369f..9858b353a263641cb8b6f7dbd14dd3ec088a7756 100644 --- a/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.cpp +++ b/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.cpp @@ -25,7 +25,7 @@ CAlgorithmLDADiscriminantFunction::CAlgorithmLDADiscriminantFunction():m_f64Bias { } -void CAlgorithmLDADiscriminantFunction::setWeight(VectorXd &rWeigth) +void CAlgorithmLDADiscriminantFunction::setWeight(const VectorXd &rWeigth) { m_oWeight = rWeigth; } @@ -35,7 +35,7 @@ void CAlgorithmLDADiscriminantFunction::setBias(float64 f64Bias) m_f64Bias = f64Bias; } -float64 CAlgorithmLDADiscriminantFunction::getValue(VectorXd &rFeatureVector) +float64 CAlgorithmLDADiscriminantFunction::getValue(const VectorXd &rFeatureVector) { return (m_oWeight.transpose() * rFeatureVector)(0) + m_f64Bias; } @@ -45,7 +45,7 @@ uint32 CAlgorithmLDADiscriminantFunction::getWeightVectorSize() return m_oWeight.size(); } -boolean CAlgorithmLDADiscriminantFunction::loadConfiguration(XML::IXMLNode *pConfiguration) +boolean CAlgorithmLDADiscriminantFunction::loadConfiguration(const XML::IXMLNode *pConfiguration) { std::stringstream l_sBias(pConfiguration->getChildByName(c_sBiasNodeName)->getPCData()); l_sBias >> m_f64Bias; diff --git a/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.h b/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.h index 6d0229c194dff734b7aec81706c286842ff236ea..1a2f9abb45400dc1c7dc2efc7373934e4408c42f 100644 --- a/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.h +++ b/plugins/processing/classification/src/algorithms/ovpCAlgorithmLDADiscriminantFunction.h @@ -21,14 +21,14 @@ namespace OpenViBEPlugins public: CAlgorithmLDADiscriminantFunction(); - void setWeight(Eigen::VectorXd &rWeigth); + void setWeight(const Eigen::VectorXd &rWeigth); void setBias(OpenViBE::float64 f64Bias); //Return the class membership of the feature vector - OpenViBE::float64 getValue(Eigen::VectorXd &rFeatureVector); + OpenViBE::float64 getValue(const Eigen::VectorXd &rFeatureVector); OpenViBE::uint32 getWeightVectorSize(void); - OpenViBE::boolean loadConfiguration(XML::IXMLNode* pConfiguration); + OpenViBE::boolean loadConfiguration(const XML::IXMLNode* pConfiguration); XML::IXMLNode* getConfiguration(void); private: