Commit ddc5c64d authored by Jussi Lindgren's avatar Jussi Lindgren
Browse files

Plugins: Classifier Processor can now reload the model on receiving a stimulation

- Additionally, refactored the code to use codec interface from toolkit
parent 87bb086f
......@@ -15,32 +15,21 @@ using namespace OpenViBEPlugins;
using namespace OpenViBEPlugins::Classification;
using namespace std;
boolean CBoxAlgorithmClassifierProcessor::initialize(void)
boolean CBoxAlgorithmClassifierProcessor::loadClassifier(const char* sFilename)
{
m_pFeaturesDecoder = NULL;
m_pLabelsEncoder = NULL;
m_pClassificationStateEncoder = NULL;
m_pProbabilityValues=NULL;
m_pClassifier = NULL;
IBox& l_rStaticBoxContext=this->getStaticBoxContext();
//First of all, let's get the XML file for configuration
CString l_sConfigurationFilename;
l_rStaticBoxContext.getSettingValue(0, l_sConfigurationFilename);
if(l_sConfigurationFilename == CString(""))
if(m_pClassifier)
{
this->getLogManager() << LogLevel_Error << "You need to specify a classifier .xml for the box (use Classifier Trainer to create one)\n";
return false;
m_pClassifier->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pClassifier);
m_pClassifier = NULL;
}
XML::IXMLHandler *l_pHandler = XML::createXMLHandler();
XML::IXMLNode *l_pRootNode = l_pHandler->parseFile(l_sConfigurationFilename.toASCIIString());
XML::IXMLNode *l_pRootNode = l_pHandler->parseFile(sFilename);
if(!l_pRootNode)
{
this->getLogManager() << LogLevel_Error << "Unable to get root node from [" << l_sConfigurationFilename << "]\n";
this->getLogManager() << LogLevel_Error << "Unable to get root node from [" << sFilename << "]\n";
return false;
}
......@@ -87,7 +76,7 @@ boolean CBoxAlgorithmClassifierProcessor::initialize(void)
//If the algorithm is still unknown, that means that we face an error
if(l_oAlgorithmClassIdentifier==OV_UndefinedIdentifier)
{
this->getLogManager() << LogLevel_Error << "Couldn't restore a classifier from the file [" << l_sConfigurationFilename << "].\n";
this->getLogManager() << LogLevel_Error << "Couldn't restore a classifier from the file [" << sFilename << "].\n";
return false;
}
}
......@@ -125,15 +114,6 @@ boolean CBoxAlgorithmClassifierProcessor::initialize(void)
this->getLogManager() << LogLevel_Warning << "The configuration file had no node " << c_sStimulationsNodeName << ". Trouble may appear later.\n";
}
m_pFeaturesDecoder=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_FeatureVectorStreamDecoder));
m_pFeaturesDecoder->initialize();
m_pLabelsEncoder=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_StimulationStreamEncoder));
m_pLabelsEncoder->initialize();
m_pClassificationStateEncoder=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_StreamedMatrixStreamEncoder));
m_pClassificationStateEncoder->initialize();
const CIdentifier l_oClassifierAlgorithmIdentifier = this->getAlgorithmManager().createAlgorithm(l_oAlgorithmClassIdentifier);
if(l_oClassifierAlgorithmIdentifier == OV_UndefinedIdentifier)
{
......@@ -145,14 +125,14 @@ boolean CBoxAlgorithmClassifierProcessor::initialize(void)
m_pClassifier=&this->getAlgorithmManager().getAlgorithm(l_oClassifierAlgorithmIdentifier);
m_pClassifier->initialize();
m_pProbabilityValues=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_StreamedMatrixStreamEncoder));
m_pProbabilityValues->initialize();
// Connect the params to the new classifier
TParameterHandler < OpenViBE::IMatrix* > ip_oFeatureVector = m_pClassifier->getInputParameter(OVTK_Algorithm_Classifier_InputParameterId_FeatureVector);
ip_oFeatureVector.setReferenceTarget(m_oFeaturesDecoder.getOutputMatrix());
m_pClassifier->getInputParameter(OVTK_Algorithm_Classifier_InputParameterId_FeatureVector)->setReferenceTarget(m_pFeaturesDecoder->getOutputParameter(OVP_GD_Algorithm_FeatureVectorStreamDecoder_OutputParameterId_Matrix));
m_pClassificationStateEncoder->getInputParameter(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputParameterId_Matrix)->setReferenceTarget(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_ClassificationValues));
m_pProbabilityValues->getInputParameter(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputParameterId_Matrix)->setReferenceTarget(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_ProbabilityValues));
m_oClassificationStateEncoder.getInputMatrix().setReferenceTarget(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_ClassificationValues));
m_oProbabilityValuesEncoder.getInputMatrix().setReferenceTarget(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_ProbabilityValues));
// note: labelsencoder cannot be directly bound here as the classifier returns a float, but we need to output a stimulation
TParameterHandler < XML::IXMLNode* > ip_pClassificationConfiguration(m_pClassifier->getInputParameter(OVTK_Algorithm_Classifier_InputParameterId_Configuration));
ip_pClassificationConfiguration = l_pRootNode->getChildByName(c_sClassifierRoot)->getChild(0);
......@@ -164,47 +144,57 @@ boolean CBoxAlgorithmClassifierProcessor::initialize(void)
l_pRootNode->release();
l_pHandler->release();
m_bOutputHeaderSent=false;
return true;
}
boolean CBoxAlgorithmClassifierProcessor::uninitialize(void)
boolean CBoxAlgorithmClassifierProcessor::initialize(void)
{
if(m_pClassifier)
{
m_pClassifier->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pClassifier);
m_pClassifier = NULL;
}
m_pClassifier = NULL;
if(m_pClassificationStateEncoder)
{
m_pClassificationStateEncoder->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pClassificationStateEncoder);
m_pClassificationStateEncoder = NULL;
}
IBox& l_rStaticBoxContext=this->getStaticBoxContext();
//First of all, let's get the XML file for configuration
CString l_sConfigurationFilename;
l_rStaticBoxContext.getSettingValue(0, l_sConfigurationFilename);
if(m_pProbabilityValues)
if(l_sConfigurationFilename == CString(""))
{
m_pProbabilityValues->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pProbabilityValues);
m_pProbabilityValues = NULL;
this->getLogManager() << LogLevel_Error << "You need to specify a classifier .xml for the box (use Classifier Trainer to create one)\n";
return false;
}
if(m_pLabelsEncoder)
m_oFeaturesDecoder.initialize(*this,0);
m_oStimulationDecoder.initialize(*this, 1);
m_oLabelsEncoder.initialize(*this, 0);
m_oClassificationStateEncoder.initialize(*this, 1);
m_oProbabilityValuesEncoder.initialize(*this, 2);
if(!loadClassifier(l_sConfigurationFilename.toASCIIString()))
{
m_pLabelsEncoder->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pLabelsEncoder);
m_pLabelsEncoder = NULL;
return false;
}
if(m_pFeaturesDecoder)
m_bOutputHeaderSent=false;
return true;
}
boolean CBoxAlgorithmClassifierProcessor::uninitialize(void)
{
if(m_pClassifier)
{
m_pFeaturesDecoder->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pFeaturesDecoder);
m_pFeaturesDecoder = NULL;
m_pClassifier->uninitialize();
this->getAlgorithmManager().releaseAlgorithm(*m_pClassifier);
m_pClassifier = NULL;
}
m_oClassificationStateEncoder.uninitialize();
m_oProbabilityValuesEncoder.uninitialize();
m_oLabelsEncoder.uninitialize();
m_oFeaturesDecoder.uninitialize();
m_oStimulationDecoder.uninitialize();
return true;
}
......@@ -221,29 +211,16 @@ boolean CBoxAlgorithmClassifierProcessor::process(void)
for(uint32 i=0; i<l_rDynamicBoxContext.getInputChunkCount(0); i++)
{
uint64 l_ui64StartTime=l_rDynamicBoxContext.getInputChunkStartTime(0, i);
uint64 l_ui64EndTime=l_rDynamicBoxContext.getInputChunkEndTime(0, i);
TParameterHandler < const IMemoryBuffer* > ip_pFeatureVectorMemoryBuffer(m_pFeaturesDecoder->getInputParameter(OVP_GD_Algorithm_FeatureVectorStreamDecoder_InputParameterId_MemoryBufferToDecode));
TParameterHandler < IMemoryBuffer* > op_pLabelsMemoryBuffer(m_pLabelsEncoder->getOutputParameter(OVP_GD_Algorithm_StimulationStreamEncoder_OutputParameterId_EncodedMemoryBuffer));
TParameterHandler < IMemoryBuffer* > op_pClassificationStateMemoryBuffer(m_pClassificationStateEncoder->getOutputParameter(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_OutputParameterId_EncodedMemoryBuffer));
TParameterHandler < IMemoryBuffer* > op_pProbabilityValues(m_pProbabilityValues->getOutputParameter(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_OutputParameterId_EncodedMemoryBuffer));
TParameterHandler < IStimulationSet* > ip_pLabelsStimulationSet(m_pLabelsEncoder->getInputParameter(OVP_GD_Algorithm_StimulationStreamEncoder_InputParameterId_StimulationSet));
TParameterHandler < float64 > op_f64ClassificationStateClass(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_Class));
ip_pFeatureVectorMemoryBuffer=l_rDynamicBoxContext.getInputChunk(0, i);
op_pLabelsMemoryBuffer=l_rDynamicBoxContext.getOutputChunk(0);
op_pClassificationStateMemoryBuffer=l_rDynamicBoxContext.getOutputChunk(1);
op_pProbabilityValues = l_rDynamicBoxContext.getOutputChunk(2);
const uint64 l_ui64StartTime=l_rDynamicBoxContext.getInputChunkStartTime(0, i);
const uint64 l_ui64EndTime=l_rDynamicBoxContext.getInputChunkEndTime(0, i);
m_pFeaturesDecoder->process();
if(m_pFeaturesDecoder->isOutputTriggerActive(OVP_GD_Algorithm_FeatureVectorStreamDecoder_OutputTriggerId_ReceivedHeader))
m_oFeaturesDecoder.decode(i);
if(m_oFeaturesDecoder.isHeaderReceived())
{
m_bOutputHeaderSent=false;
}
if(m_pFeaturesDecoder->isOutputTriggerActive(OVP_GD_Algorithm_FeatureVectorStreamDecoder_OutputTriggerId_ReceivedBuffer))
{
if(m_oFeaturesDecoder.isBufferReceived())
{
if(m_pClassifier->process(OVTK_Algorithm_Classifier_InputTriggerId_Classify))
{
if (m_pClassifier->isOutputTriggerActive(OVTK_Algorithm_Classifier_OutputTriggerId_Success))
......@@ -251,23 +228,29 @@ boolean CBoxAlgorithmClassifierProcessor::process(void)
//this->getLogManager() << LogLevel_Warning << "---Classification successful---\n";
if(!m_bOutputHeaderSent)
{
m_pLabelsEncoder->process(OVP_GD_Algorithm_StimulationStreamEncoder_InputTriggerId_EncodeHeader);
m_pClassificationStateEncoder->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeHeader);
m_pProbabilityValues->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeHeader);
m_oLabelsEncoder.encodeHeader();
m_oClassificationStateEncoder.encodeHeader();
m_oProbabilityValuesEncoder.encodeHeader();
l_rDynamicBoxContext.markOutputAsReadyToSend(0, l_ui64StartTime, l_ui64StartTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(1, l_ui64StartTime, l_ui64StartTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(2, l_ui64StartTime, l_ui64StartTime);
m_bOutputHeaderSent=true;
}
ip_pLabelsStimulationSet->setStimulationCount(1);
ip_pLabelsStimulationSet->setStimulationIdentifier(0, m_vStimulation[op_f64ClassificationStateClass]);
ip_pLabelsStimulationSet->setStimulationDate(0, l_ui64EndTime);
ip_pLabelsStimulationSet->setStimulationDuration(0, 0);
TParameterHandler < float64 > op_f64ClassificationStateClass(m_pClassifier->getOutputParameter(OVTK_Algorithm_Classifier_OutputParameterId_Class));
IStimulationSet* l_pSet = m_oLabelsEncoder.getInputStimulationSet();
l_pSet->setStimulationCount(1);
l_pSet->setStimulationIdentifier(0, m_vStimulation[op_f64ClassificationStateClass]);
l_pSet->setStimulationDate(0, l_ui64EndTime);
l_pSet->setStimulationDuration(0, 0);
m_oLabelsEncoder.encodeBuffer();
m_oClassificationStateEncoder.encodeBuffer();
m_oProbabilityValuesEncoder.encodeBuffer();
m_pLabelsEncoder->process(OVP_GD_Algorithm_StimulationStreamEncoder_InputTriggerId_EncodeBuffer);
m_pClassificationStateEncoder->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeBuffer);
m_pProbabilityValues->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeBuffer);
l_rDynamicBoxContext.markOutputAsReadyToSend(0, l_ui64StartTime, l_ui64EndTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(1, l_ui64StartTime, l_ui64EndTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(2, l_ui64StartTime, l_ui64EndTime);
......@@ -284,11 +267,12 @@ boolean CBoxAlgorithmClassifierProcessor::process(void)
return false;
}
}
if(m_pFeaturesDecoder->isOutputTriggerActive(OVP_GD_Algorithm_FeatureVectorStreamDecoder_OutputTriggerId_ReceivedEnd))
if(m_oFeaturesDecoder.isEndReceived())
{
m_pLabelsEncoder->process(OVP_GD_Algorithm_StimulationStreamEncoder_InputTriggerId_EncodeEnd);
m_pClassificationStateEncoder->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeEnd);
m_pProbabilityValues->process(OVP_GD_Algorithm_StreamedMatrixStreamEncoder_InputTriggerId_EncodeEnd);
m_oLabelsEncoder.encodeEnd();
m_oClassificationStateEncoder.encodeEnd();
m_oProbabilityValuesEncoder.encodeEnd();
l_rDynamicBoxContext.markOutputAsReadyToSend(0, l_ui64StartTime, l_ui64EndTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(1, l_ui64StartTime, l_ui64EndTime);
l_rDynamicBoxContext.markOutputAsReadyToSend(2, l_ui64StartTime, l_ui64EndTime);
......@@ -297,5 +281,39 @@ boolean CBoxAlgorithmClassifierProcessor::process(void)
l_rDynamicBoxContext.markInputAsDeprecated(0, i);
}
// Check if we have a command
for(uint32 i=0; i<l_rDynamicBoxContext.getInputChunkCount(1); i++)
{
m_oStimulationDecoder.decode(i);
if(m_oStimulationDecoder.isHeaderReceived())
{
}
if(m_oStimulationDecoder.isBufferReceived())
{
for(uint64 i=0;i<m_oStimulationDecoder.getOutputStimulationSet()->getStimulationCount();i++)
{
if(m_oStimulationDecoder.getOutputStimulationSet()->getStimulationIdentifier(i) == OVTK_StimulationId_TrainCompleted)
{
IBox& l_rStaticBoxContext=this->getStaticBoxContext();
CString l_sConfigurationFilename;
l_rStaticBoxContext.getSettingValue(0, l_sConfigurationFilename);
this->getLogManager() << LogLevel_Trace << "Reloading classifier\n";
if(!loadClassifier(l_sConfigurationFilename.toASCIIString()))
{
this->getLogManager() << LogLevel_Error << "Error reloading classifier\n";
return false;
}
}
}
}
if(m_oStimulationDecoder.isEndReceived())
{
}
}
return true;
}
......@@ -27,12 +27,18 @@ namespace OpenViBEPlugins
_IsDerivedFromClass_Final_(OpenViBEToolkit::TBoxAlgorithm < OpenViBE::Plugins::IBoxAlgorithm >, OVP_ClassId_BoxAlgorithm_ClassifierProcessor)
protected:
virtual OpenViBE::boolean loadClassifier(const char *sFilename);
private:
OpenViBE::Kernel::IAlgorithmProxy* m_pFeaturesDecoder;
OpenViBE::Kernel::IAlgorithmProxy* m_pLabelsEncoder;
OpenViBE::Kernel::IAlgorithmProxy* m_pClassificationStateEncoder;
OpenViBE::Kernel::IAlgorithmProxy* m_pProbabilityValues;
OpenViBEToolkit::TFeatureVectorDecoder< CBoxAlgorithmClassifierProcessor > m_oFeaturesDecoder;
OpenViBEToolkit::TStimulationDecoder< CBoxAlgorithmClassifierProcessor > m_oStimulationDecoder;
OpenViBEToolkit::TStimulationEncoder< CBoxAlgorithmClassifierProcessor > m_oLabelsEncoder;
OpenViBEToolkit::TStreamedMatrixEncoder< CBoxAlgorithmClassifierProcessor > m_oClassificationStateEncoder;
OpenViBEToolkit::TStreamedMatrixEncoder< CBoxAlgorithmClassifierProcessor > m_oProbabilityValuesEncoder;
OpenViBE::Kernel::IAlgorithmProxy* m_pClassifier;
std::map < OpenViBE::float64, OpenViBE::uint64 > m_vStimulation;
......@@ -51,7 +57,7 @@ namespace OpenViBEPlugins
virtual OpenViBE::CString getShortDescription(void) const { return OpenViBE::CString("Generic classification, relying on several box algorithms"); }
virtual OpenViBE::CString getDetailedDescription(void) const { return OpenViBE::CString("Classifies incoming feature vectors using a previously learned classifier."); }
virtual OpenViBE::CString getCategory(void) const { return OpenViBE::CString("Classification"); }
virtual OpenViBE::CString getVersion(void) const { return OpenViBE::CString("2.0"); }
virtual OpenViBE::CString getVersion(void) const { return OpenViBE::CString("2.1"); }
virtual OpenViBE::CString getStockItemName(void) const { return OpenViBE::CString("gtk-apply"); }
virtual OpenViBE::CIdentifier getCreatedClass(void) const { return OVP_ClassId_BoxAlgorithm_ClassifierProcessor; }
......@@ -61,6 +67,7 @@ namespace OpenViBEPlugins
OpenViBE::Kernel::IBoxProto& rBoxAlgorithmPrototype) const
{
rBoxAlgorithmPrototype.addInput ("Features", OV_TypeId_FeatureVector);
rBoxAlgorithmPrototype.addInput ("Commands", OV_TypeId_Stimulations);
rBoxAlgorithmPrototype.addOutput ("Labels", OV_TypeId_Stimulations);
rBoxAlgorithmPrototype.addOutput ("Hyperplane distance", OV_TypeId_StreamedMatrix);
rBoxAlgorithmPrototype.addOutput ("Probability values", OV_TypeId_StreamedMatrix);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment